from math import log2, sqrt
import torch
from torch import nn, einsum
import torch.nn.functional as F
import numpy as np

from axial_positional_embedding import AxialPositionalEmbedding
from einops import rearrange

from dalle_pytorch import distributed_utils
from dalle_pytorch.vae import OpenAIDiscreteVAE, VQGanVAE
from dalle_pytorch.transformer import Transformer, DivideMax

# helpers
import torch
import torch.nn as nn

import sys
# sys.path.insert(0, '/home/tiangel/DALLE_3D/Learning-to-Group')
# from core.nn import SharedMLP
# from core.nn.init import xavier_uniform, set_bn
# from shaper.models.pointnet2.modules import PointNetSAModule, PointnetFPModule
# helpers
from IPython import embed
from pytorch3d.loss import chamfer_distance
from emd_3d.emd_module import emdModule
# sys.path.insert(0,'/home/tiangel/DALLE_newest/UnsupervisedPointCloudReconstruction')
# from UnsupervisedPointCloudReconstruction.model import DGCNN_Seg_Encoder, DGCNN_Seg_Encoder2,DGCNN_Seg_Encoder3,FoldNet_Encoder, FoldNet_Decoder, FoldNet_Decoder3, FoldNet_Decoder4, FoldNet_Decoder5, FoldNet_Decoder6

# sys.path.insert(0, '/home/tiangel/DALLE_3D/Learning-to-Group')
sys.path.insert(0, '/home/tiangel/Learning-to-Group')
from shaper.models.pointnet2.modules import PointNetSAModule, PointnetFPModule
from partnet.utils.torch_pc import normalize_points as normalize_points_torch
from core.nn.init import xavier_uniform, set_bn

def exists(val):
    """Return True when `val` is anything other than None."""
    is_missing = val is None
    return not is_missing

def default(val, d):
    """Return `val` unless it is None, in which case return the fallback `d`."""
    if exists(val):
        return val
    return d

class always():
    """Callable that ignores whatever it is called with and returns a fixed value."""

    def __init__(self, val):
        # The value handed back on every call.
        self.val = val

    def __call__(self, x, *args, **kwargs):
        # One positional argument is required; it and any extras are discarded.
        constant = self.val
        return constant

def is_empty(t):
    """Return True when `t` holds no elements (its `nelement()` is zero)."""
    count = t.nelement()
    return count == 0

def masked_mean(t, mask, dim = 1):
    """Average `t` over axis 1, counting only positions where `mask` is True.

    `mask` is indexed as `mask[:, :, None]`, so it must be 2-D and is
    broadcast over the last axis of `t`. Rows whose mask is entirely False
    divide by zero.

    NOTE(review): `dim` is accepted but never used; the reduction is always
    over axis 1.
    """
    keep = mask[:, :, None]
    zeroed = t.masked_fill(~keep, 0.)
    totals = zeroed.sum(dim = 1)
    counts = mask.sum(dim = 1)[..., None]
    return totals / counts

def prob_mask_like(shape, prob, device):
    """Return a boolean tensor of `shape` on `device`; each entry is True where
    a uniform [0, 1) draw falls below `prob`."""
    draws = torch.zeros(shape, device = device).float()
    draws.uniform_(0, 1)
    return draws < prob

def set_requires_grad(model, value):
    """Set `requires_grad` to `value` on every parameter yielded by `model.parameters()`."""
    parameters = model.parameters()
    for tensor in parameters:
        tensor.requires_grad = value

def eval_decorator(fn):
    """Decorate `fn(model, ...)` so that it runs with `model` in eval mode.

    The model's previous training flag is recorded, `model.eval()` is called,
    and the flag is restored with `model.train(was_training)` afterwards.
    The restore happens in a `finally` block, so an exception raised by `fn`
    no longer leaves the model stuck in eval mode.

    Returns the wrapped function; its name and docstring are preserved.
    """
    from functools import wraps

    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            # Always put the model back into the mode the caller had it in.
            model.train(was_training)
    return inner

# sampling helpers

def log(t, eps = 1e-20):
    """Natural log of `t`, with `eps` added first to keep log(0) finite."""
    shifted = t + eps
    return torch.log(shifted)

def gumbel_noise(t):
    """Return Gumbel noise shaped like `t`, computed as -log(-log(u)) with
    u drawn uniformly from [0, 1)."""
    uniform = torch.zeros_like(t).uniform_(0, 1)
    inner = log(uniform)
    return -log(-inner)

def gumbel_sample(t, temperature = 1., dim = -1):
    """Sample indices along `dim` by adding Gumbel noise to `t / temperature`
    and taking the argmax."""
    scaled = t / temperature
    perturbed = scaled + gumbel_noise(t)
    return perturbed.argmax(dim = dim)

def top_k(logits, thres = 0.5):
    """Keep the largest logits along the last axis and set the rest to -inf.

    The number kept is `max(int((1 - thres) * num_logits), 1)`. The kept
    values are scattered back along axis 1, so `logits` is expected to be 2-D.
    """
    vocab = logits.shape[-1]
    keep = max(int((1 - thres) * vocab), 1)
    top_vals, top_idx = torch.topk(logits, keep)
    filtered = torch.full_like(logits, float('-inf'))
    filtered.scatter_(1, top_idx, top_vals)
    return filtered

def to_contiguous(tensor):
    """Return `tensor` itself if it is already contiguous, otherwise a contiguous copy."""
    return tensor if tensor.is_contiguous() else tensor.contiguous()

class SharedEmbedding(nn.Embedding):
    """Embedding whose lookup table is a row slice of another layer's weight.

    Rows `start_index:end_index` of `linear.weight` serve as the embedding
    matrix, so the two layers share those parameters. The embedding's own
    `weight` parameter is deleted after construction.
    """

    def __init__(self, linear, start_index, end_index, **kwargs):
        num_rows = end_index - start_index
        width = linear.weight.shape[1]
        super().__init__(num_rows, width, **kwargs)
        # The table comes from `linear`; drop the parameter nn.Embedding allocated.
        del self.weight

        self.linear = linear
        self.start_index = start_index
        self.end_index = end_index

    def forward(self, input):
        table = self.linear.weight[self.start_index:self.end_index]
        return F.embedding(
            input,
            table,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )


class VQVAE_Decoder(nn.Module):
    """Decoder built from three point-wise (kernel size 1) Conv1d stacks.

    Each stack widens the channels to `2 * codebook_dim` and narrows them
    again. The first two stacks are applied residually. The third outputs 512
    channels, which are max-pooled over the last axis and mapped by
    `our_end` to `final_dim * 3` channels.

    `feat_dims` and `radius` are accepted but not used by this class.
    """

    def __init__(self, feat_dims, codebook_dim=512, radius=0.3, final_dim=2048):
        super().__init__()
        self.dim = codebook_dim
        hidden = 2 * self.dim

        def fold(out_channels):
            # dim -> hidden -> hidden -> out_channels, all 1x1 convolutions
            return nn.Sequential(
                nn.Conv1d(self.dim, hidden, 1),
                nn.BatchNorm1d(hidden),
                nn.ReLU(),
                nn.Conv1d(hidden, hidden, 1),
                nn.BatchNorm1d(hidden),
                nn.ReLU(),
                nn.Conv1d(hidden, out_channels, 1),
            )

        self.folding1 = fold(self.dim)
        self.folding2 = fold(self.dim)
        self.folding3 = fold(512)
        self.our_end = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(512, 1024, 1),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Conv1d(1024, final_dim*3, 1)
        )

    def forward(self, x):
        # x: (batch, codebook_dim, length)
        x = x + self.folding1(x)
        x = x + self.folding2(x)
        features = self.folding3(x)                                # (batch, 512, length)
        pooled = torch.max(features, -1, keepdim=True)[0]          # (batch, 512, 1)
        return self.our_end(pooled)                                # (batch, final_dim * 3, 1)

class VQVAE_Decoder_wide8(nn.Module):
    """Decoder variant whose three Conv1d stacks use a hidden width of
    `8 * codebook_dim`.

    The first two stacks are applied residually. The third outputs 512
    channels, which are max-pooled over the last axis and mapped by
    `our_end` to `final_dim * 3` channels.

    `feat_dims` and `radius` are accepted but not used by this class.
    """

    def __init__(self, feat_dims, codebook_dim=512, radius=0.3, final_dim=2048):
        super().__init__()
        self.dim = codebook_dim
        hidden = 8 * self.dim

        def fold(out_channels):
            # dim -> hidden -> hidden -> out_channels, all 1x1 convolutions
            return nn.Sequential(
                nn.Conv1d(self.dim, hidden, 1),
                nn.BatchNorm1d(hidden),
                nn.ReLU(),
                nn.Conv1d(hidden, hidden, 1),
                nn.BatchNorm1d(hidden),
                nn.ReLU(),
                nn.Conv1d(hidden, out_channels, 1),
            )

        self.folding1 = fold(self.dim)
        self.folding2 = fold(self.dim)
        self.folding3 = fold(512)
        self.our_end = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(512, 1024, 1),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Conv1d(1024, final_dim*3, 1)
        )

    def forward(self, x):
        # x: (batch, codebook_dim, length)
        x = x + self.folding1(x)
        x = x + self.folding2(x)
        features = self.folding3(x)                                # (batch, 512, length)
        pooled = torch.max(features, -1, keepdim=True)[0]          # (batch, 512, 1)
        return self.our_end(pooled)                                # (batch, final_dim * 3, 1)
class VQVAE_Decoder_wide5(nn.Module):
    """Decoder variant whose three Conv1d stacks use a hidden width of
    `5 * codebook_dim`.

    The first two stacks are applied residually. The third outputs 512
    channels, which are max-pooled over the last axis and mapped by
    `our_end` to `final_dim * 3` channels.

    `feat_dims` and `radius` are accepted but not used by this class.
    """

    def __init__(self, feat_dims, codebook_dim=512, radius=0.3, final_dim=2048):
        super().__init__()
        self.dim = codebook_dim
        hidden = 5 * self.dim

        def fold(out_channels):
            # dim -> hidden -> hidden -> out_channels, all 1x1 convolutions
            return nn.Sequential(
                nn.Conv1d(self.dim, hidden, 1),
                nn.BatchNorm1d(hidden),
                nn.ReLU(),
                nn.Conv1d(hidden, hidden, 1),
                nn.BatchNorm1d(hidden),
                nn.ReLU(),
                nn.Conv1d(hidden, out_channels, 1),
            )

        self.folding1 = fold(self.dim)
        self.folding2 = fold(self.dim)
        self.folding3 = fold(512)
        self.our_end = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(512, 1024, 1),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Conv1d(1024, final_dim*3, 1)
        )

    def forward(self, x):
        # x: (batch, codebook_dim, length)
        x = x + self.folding1(x)
        x = x + self.folding2(x)
        features = self.folding3(x)                                # (batch, 512, length)
        pooled = torch.max(features, -1, keepdim=True)[0]          # (batch, 512, 1)
        return self.our_end(pooled)                                # (batch, final_dim * 3, 1)
class VQVAE_Decoder_depth5(nn.Module):
    """Decoder variant with five Conv1d stacks of hidden width `4 * codebook_dim`.

    The first four stacks are applied residually. The fifth outputs 512
    channels, which are max-pooled over the last axis and mapped by
    `our_end` to `final_dim * 3` channels.

    `feat_dims` and `radius` are accepted but not used by this class.
    """

    def __init__(self, feat_dims, codebook_dim=512, radius=0.3, final_dim=2048):
        super().__init__()
        self.dim = codebook_dim
        hidden = 4 * self.dim

        def fold(out_channels):
            # dim -> hidden -> hidden -> out_channels, all 1x1 convolutions
            return nn.Sequential(
                nn.Conv1d(self.dim, hidden, 1),
                nn.BatchNorm1d(hidden),
                nn.ReLU(),
                nn.Conv1d(hidden, hidden, 1),
                nn.BatchNorm1d(hidden),
                nn.ReLU(),
                nn.Conv1d(hidden, out_channels, 1),
            )

        self.folding1 = fold(self.dim)
        self.folding2 = fold(self.dim)
        self.folding3 = fold(self.dim)
        self.folding4 = fold(self.dim)
        self.folding5 = fold(512)
        self.our_end = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(512, 1024, 1),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Conv1d(1024, final_dim*3, 1)
        )

    def forward(self, x):
        # x: (batch, codebook_dim, length)
        residual_stacks = (self.folding1, self.folding2, self.folding3, self.folding4)
        for stack in residual_stacks:
            x = x + stack(x)
        features = self.folding5(x)                                # (batch, 512, length)
        pooled = torch.max(features, -1, keepdim=True)[0]          # (batch, 512, 1)
        return self.our_end(pooled)                                # (batch, final_dim * 3, 1)

class VQVAE_Decoder_transf(nn.Module):
    """Decoder variant that runs a 3-layer full-attention `Transformer`, then
    max-pools over the last axis and applies the `our_end` Conv1d head.

    The transformer is built with `dim = feat_dims` and
    `seq_len = codebook_dim`. `radius` is accepted but unused.

    NOTE(review): `our_end` starts with `Conv1d(512, ...)`, so after pooling
    axis 1 of the transformer output must have size 512 — confirm against the
    intended input layout.
    """

    def __init__(self, feat_dims, codebook_dim=512, radius=0.3, final_dim=2048):
        super().__init__()
        transformer_kwargs = dict(
            dim = feat_dims,
            seq_len = codebook_dim,
            depth=3,
            heads=8,
            dim_head=64,
            rotary_emb= False,
            attn_types=('full',),
        )
        self.transf = Transformer(**transformer_kwargs)
        self.our_end = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(512, 1024, 1),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Conv1d(1024, final_dim*3, 1)
        )

    def forward(self, x):
        attended = self.transf(x)
        pooled = torch.max(attended, -1, keepdim=True)[0]
        return self.our_end(pooled)          # channel axis has final_dim * 3 entries

class ResBlock(nn.Module):
    """2-D residual block: conv3x3 -> ReLU -> conv3x3 -> ReLU -> conv1x1,
    added back onto the input. The channel count `chan` is preserved."""

    def __init__(self, chan):
        super().__init__()
        layers = [
            nn.Conv2d(chan, chan, 3, padding = 1),
            nn.ReLU(),
            nn.Conv2d(chan, chan, 3, padding = 1),
            nn.ReLU(),
            nn.Conv2d(chan, chan, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.net(x)
        return residual + x

class VectorQuantizer(nn.Module):
    """Vector-quantisation layer with a learned codebook (VQ-VAE style).

    Args:
        num_embeddings: number of codebook entries.
        embedding_dim: length of each codebook vector; must equal the channel
            dimension of the tensors passed to `forward`.
        commitment_cost: weight of the encoder commitment term in the loss.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25):
        super().__init__()

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings

        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.uniform_(-1/self._num_embeddings, 1/self._num_embeddings)
        self._commitment_cost = commitment_cost

    def forward(self, inputs):
        """Quantise `inputs` to the nearest codebook vectors.

        Args:
            inputs: tensor of shape (B, embedding_dim, H, W).

        Returns:
            Tuple `(loss, quantized, perplexity, indices)`:
            loss is the codebook loss plus `commitment_cost` times the
            commitment loss; quantized has the same (B, C, H, W) shape as the
            input with straight-through gradients; perplexity is the exponential
            of the entropy of the average code usage; indices has shape
            (B, H * W) and holds the selected codebook ids.

        Raises:
            ValueError: if `inputs` is not 4-D or its channel dimension differs
                from `embedding_dim`.
        """
        if inputs.dim() != 4 or inputs.shape[1] != self._embedding_dim:
            raise ValueError(
                'expected input of shape (B, %d, H, W), got %s'
                % (self._embedding_dim, tuple(inputs.shape)))

        # BCHW -> BHWC so that each row of the flattened view is one vector.
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape

        # Flatten by the vector length. The previous code flattened by the
        # number of codebook entries, which only worked when the two matched.
        flat_input = inputs.view(-1, self._embedding_dim)

        # Squared Euclidean distance to every codebook entry, expanded as
        # |x|^2 + |e|^2 - 2 x.e
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
                    + torch.sum(self._embedding.weight**2, dim=1)
                    - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

        # One-hot encoding of the nearest code for every input vector.
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Look up the selected codes and restore the BHWC shape.
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)

        # Commitment loss pulls the encoder towards the codes; the codebook
        # loss pulls the codes towards the encoder outputs.
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        loss = q_latent_loss + self._commitment_cost * e_latent_loss

        # Straight-through estimator: forward uses the codes, backward passes
        # gradients to `inputs` unchanged.
        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # BHWC -> BCHW for the quantised output.
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encoding_indices.reshape(inputs.shape[0],-1)

class VectorQuantizerEMA(nn.Module):
    """Vector-quantisation layer whose codebook is updated by exponential
    moving averages (in training mode) rather than by a codebook loss term.

    Args:
        num_embeddings: number of codebook entries.
        embedding_dim: length of each codebook vector (input channel size).
        commitment_cost: weight of the commitment loss.
        decay: EMA decay for the cluster sizes and summed vectors.
        epsilon: Laplace-smoothing constant for the cluster sizes.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25, decay=0.99, epsilon=1e-5):
        super().__init__()

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings

        self._embedding = nn.Embedding(num_embeddings, embedding_dim)
        self._embedding.weight.data.normal_()
        self._commitment_cost = commitment_cost

        # Running statistics for the EMA codebook update.
        self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
        self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
        self._ema_w.data.normal_()

        self._decay = decay
        self._epsilon = epsilon

    def forward(self, inputs):
        """Quantise a (B, embedding_dim, H, W) tensor.

        Returns `(loss, quantized, perplexity, encodings)`: `quantized` is
        BCHW with straight-through gradients and `encodings` is the
        (B * H * W, num_embeddings) one-hot assignment matrix.
        """
        # BCHW -> BHWC, then one row per vector
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape
        flat_input = inputs.view(-1, self._embedding_dim)

        # Squared distance to each code: |x|^2 + |e|^2 - 2 x.e
        input_sq = torch.sum(flat_input**2, dim=1, keepdim=True)
        code_sq = torch.sum(self._embedding.weight**2, dim=1)
        cross = torch.matmul(flat_input, self._embedding.weight.t())
        distances = input_sq + code_sq - 2 * cross

        # One-hot assignment of every vector to its nearest code
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        num_vectors = encoding_indices.shape[0]
        encodings = torch.zeros(num_vectors, self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Codes are looked up before the EMA step below replaces the codebook
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)

        if self.training:
            decay = self._decay
            counts = torch.sum(encodings, 0)
            self._ema_cluster_size = self._ema_cluster_size * decay + \
                                     (1 - decay) * counts

            # Laplace smoothing keeps every cluster size strictly positive
            n = torch.sum(self._ema_cluster_size.data)
            smoothed = self._ema_cluster_size + self._epsilon
            self._ema_cluster_size = (
                smoothed / (n + self._num_embeddings * self._epsilon) * n)

            # Per-code sum of the assigned input vectors
            dw = torch.matmul(encodings.t(), flat_input)
            self._ema_w = nn.Parameter(self._ema_w * decay + (1 - decay) * dw)

            new_codebook = self._ema_w / self._ema_cluster_size.unsqueeze(1)
            self._embedding.weight = nn.Parameter(new_codebook)

        # Only the commitment term contributes to the loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        loss = self._commitment_cost * e_latent_loss

        # Straight-through estimator
        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # BHWC -> BCHW
        quantized = quantized.permute(0, 3, 1, 2).contiguous()
        return loss, quantized, perplexity, encodings

class DiscreteVAE(nn.Module):
    """Point-cloud VQ-VAE: a stack of `PointNetSAModule` encoders, a
    `VectorQuantizer` bottleneck and one of the `VQVAE_Decoder*` decoders.

    Constructor arguments that are actually used:
        image_size: stored; must be a power of two (asserted).
        num_tokens: codebook size and the channel width of the last encoder
            layer.
        codebook_dim: length of each codebook vector and decoder input width.
            The quantizer's distance computation multiplies the encoder
            features (width `num_tokens`) with the codebook, so `num_tokens`
            and `codebook_dim` must be equal.
        num_layers: stored; must be >= 1 (asserted).
        smooth_l1_loss: selects `self.loss_fn` (smooth L1 vs. MSE).
        temperature, straight_through, kl_div_loss_weight, normalization:
            stored on the instance.
        radius: forwarded to the decoder constructor.
        final_points: number of centroids of the last encoder layer.
        final_dim: number of points the decoder output is reshaped to.
        vae_type: decoder selector (1: VQVAE_Decoder, 2: depth5, 3: wide8,
            4: wide5). Any other value leaves `self.decoder` unset.
        vae_encode_type: encoder configuration selector (1, 2 or 3).

    The remaining arguments (`num_resnet_blocks`, `hidden_dim`, `channels`,
    `dim1`, `dim2`) are accepted for interface compatibility but unused.

    Raises:
        ValueError: if `vae_encode_type` is not 1, 2 or 3.
    """

    def __init__(
        self,
        image_size = 256,
        num_tokens = 512,
        codebook_dim = 512,
        num_layers = 3,
        num_resnet_blocks = 0,
        hidden_dim = 64,
        channels = 3,
        smooth_l1_loss = False,
        temperature = 0.9,
        straight_through = False,
        kl_div_loss_weight = 0.,
        normalization = ((0.5,) * 3, (0.5,) * 3),
        dim1 = 16,
        dim2 = 32,
        radius= 0.3,
        final_points = 16,
        final_dim = 2048,
        vae_type = 1,
        vae_encode_type = 1,
    ):
        super().__init__()
        assert log2(image_size).is_integer(), 'image size must be a power of 2'
        assert num_layers >= 1, 'number of layers must be greater than or equal to 1'

        self.image_size = image_size
        self.num_tokens = num_tokens
        self.num_layers = num_layers
        self.temperature = temperature
        self.straight_through = straight_through
        self.codebook_dim = codebook_dim
        self.final_dim = final_dim
        self.final_points = final_points

        self.quantize_layer = VectorQuantizer(num_tokens, codebook_dim)

        if vae_type == 1:
            self.decoder = VQVAE_Decoder(self.final_points, codebook_dim, radius, final_dim)
        elif vae_type == 2:
            self.decoder = VQVAE_Decoder_depth5(self.final_points, codebook_dim, radius, final_dim)
        elif vae_type == 3:
            self.decoder = VQVAE_Decoder_wide8(self.final_points, codebook_dim, radius, final_dim)
        elif vae_type == 4:
            self.decoder = VQVAE_Decoder_wide5(self.final_points, codebook_dim, radius, final_dim)

        # Encoder configuration. The local is named `sa_radius` so it does not
        # shadow the `radius` argument that was handed to the decoder above.
        in_channels = 6
        if vae_encode_type == 1:
            num_centroids = (512, self.final_points)
            sa_radius = (0.1, 0.4)
            num_neighbours = (64, 512)
            sa_channels = ((512, 512), (512, num_tokens))
        elif vae_encode_type == 2:
            num_centroids = (1024, 512, self.final_points)
            sa_radius = (0.1, 0.2, 0.4)
            num_neighbours = (32, 64, 256)
            sa_channels = ((256, 256), (256, 256), (512, num_tokens))
        elif vae_encode_type == 3:
            num_centroids = (512, self.final_points)
            sa_radius = (0.1, 0.3)
            num_neighbours = (64, 256)
            sa_channels = ((512, 512), (512, num_tokens))
        else:
            # The original built the exception without raising it, so the
            # failure surfaced later as an unrelated unbound-variable error.
            raise ValueError('unsupported encoder type')
        use_xyz = True
        num_sa_layers = len(num_centroids)

        feature_channels = in_channels - 3
        self.sa_modules = nn.ModuleList()
        for ind in range(num_sa_layers):
            sa_module = PointNetSAModule(in_channels=feature_channels,
                                         mlp_channels=sa_channels[ind],
                                         num_centroids=num_centroids[ind],
                                         radius=sa_radius[ind],
                                         num_neighbours=num_neighbours[ind],
                                         use_xyz=use_xyz)
            self.sa_modules.append(sa_module)
            # The next layer consumes the last MLP width of this one.
            feature_channels = sa_channels[ind][-1]
        self.reset_parameters()

        self.loss_emd = emdModule()

        self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
        self.kl_div_loss_weight = kl_div_loss_weight

        # take care of normalization within class
        self.normalization = normalization

        self._register_external_parameters()

    def reset_parameters(self):
        """Re-initialise the encoder layers with `xavier_uniform` and apply
        `set_bn(self, momentum=0.01)` to the whole model."""
        for sa_module in self.sa_modules:
            sa_module.reset_parameters(xavier_uniform)
        set_bn(self, momentum=0.01)

    def _register_external_parameters(self):
        """Register external parameters for DeepSpeed partitioning."""
        if (
                not distributed_utils.is_distributed
                or not distributed_utils.using_backend(
                    distributed_utils.DeepSpeedBackend)
        ):
            return

        # Registration is currently disabled: the model no longer owns a
        # `codebook` embedding to register.
        deepspeed = distributed_utils.backend.backend_module
        # deepspeed.zero.register_external_parameter(self, self.codebook.weight)

    def norm(self, images):
        """Return `images` normalised per channel with `self.normalization`
        (a `(means, stds)` pair), or unchanged when normalization is None.
        The input is cloned first, so it is not modified."""
        if not exists(self.normalization):
            return images

        means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization)
        means, stds = map(lambda t: rearrange(t, 'c -> () c () ()'), (means, stds))
        images = images.clone()
        images.sub_(means).div_(stds)
        return images

    @torch.no_grad()
    @eval_decorator
    def get_codebook_indices(self, images):
        """Encode `images` and return the codebook indices chosen by the
        quantizer. Runs in eval mode without gradients."""
        logits = self(images, return_logits = True)
        _, _, _, indices = self.quantize_layer(logits.unsqueeze(-1))
        return indices

    def decode(
        self,
        indices
    ):
        """Decode codebook `indices` of shape (batch, tokens) into point
        clouds of shape (batch, -1, 3)."""
        flat_indices = indices.reshape(-1,1)
        # One-hot rows multiplied with the codebook pick the selected vectors.
        encodings = torch.zeros(flat_indices.shape[0], self.quantize_layer._num_embeddings, device=indices.device)
        encodings.scatter_(1, flat_indices, 1)
        quantized = torch.matmul(encodings, self.quantize_layer._embedding.weight).view(indices.shape[0],indices.shape[1],1,-1).permute(0,3,1,2)
        pcs = self.decoder(quantized.squeeze(-1)).reshape(quantized.shape[0],-1,3)
        return pcs

    def get_encoding(self, images):
        """Encode `images` and return the quantised feature tensor."""
        logits = self(images, return_logits = True)
        _, quantized, _, _ = self.quantize_layer(logits.unsqueeze(-1))
        return quantized

    def forward(
        self,
        img,
        return_loss = False,
        return_recons = False,
        return_logits = False,
        return_detailed_loss = False,
        temp = None,
        epoch = 0,
    ):
        """Encode, quantise and decode a batch of point clouds.

        Args:
            img: input point clouds; axis 1 is treated as the point axis and
                axis 2 is transposed in front of it for the encoder
                (presumably (batch, num_points, 3) — TODO confirm).
            return_logits: return the raw encoder features and stop.
            return_loss: compute the training loss instead of returning the
                decoder output.
            return_recons: with `return_loss`, also return the reconstruction.
            return_detailed_loss: with `return_recons`, return the individual
                loss terms instead of their sum.
            temp, epoch: accepted but unused.

        Returns:
            - encoder features if `return_logits`;
            - decoder output if not `return_loss`;
            - `loss` if `return_loss` and not `return_recons`;
            - `(loss, out, perplexity)` if `return_recons`;
            - `(cd_loss, emd_loss, vq_loss, out, perplexity)` if additionally
              `return_detailed_loss`.
        """
        device = img.device

        points = img.transpose(1,2)
        xyz = points
        feature = points
        for sa_module in self.sa_modules:
            xyz, feature = sa_module(xyz, feature)

        if return_logits:
            return feature # return logits for getting hard image indices for DALL-E training

        vq_loss, sampled, perplexity, _ = self.quantize_layer(feature.unsqueeze(-1))

        out = self.decoder(sampled.squeeze(-1))

        if not return_loss:
            return out

        # Reconstruction loss against a random subset of `final_dim` input points.
        points = points.transpose(1,2)
        idx_arr = np.arange(points.shape[1])
        np.random.shuffle(idx_arr)
        target = points[:,idx_arr[:self.final_dim],:]
        recon = out.reshape(-1,self.final_dim,3)

        recon_loss = 50*chamfer_distance(target, recon)[0]
        cd_loss = recon_loss.clone()
        if len(idx_arr) < self.final_dim:
            # Fewer input points than decoder output points: the EMD term is
            # skipped. The placeholder lives on the input's device instead of
            # a hard-coded `.cuda()`.
            emd_loss = torch.zeros(1, dtype=torch.float32, device=device, requires_grad=True)
        else:
            recon_loss += 20*self.loss_emd(target, recon, 0.02, 100)[0].mean(1).mean()
            emd_loss = recon_loss - cd_loss

        loss = recon_loss + vq_loss
        print('recon_loss:%.3f = cd_loss %.3f + emd_loss %.3f, vq_loss:%.3f, perplexity:%.3f'%(recon_loss, cd_loss, emd_loss, vq_loss, perplexity))

        if not return_recons:
            return loss

        if not return_detailed_loss:
            return loss, out, perplexity
        else:
            return cd_loss, emd_loss, vq_loss, out, perplexity

class VQProgram(nn.Module):
    """Transformer head that maps a sequence of 128 feature vectors to
    `30 * 29` output values.

    Args:
        depth: number of transformer layers.
        dim: width of the input features (also used for the LayerNorm and the
            projection to 32 channels).
        heads: number of attention heads.
        dim_head: width of each attention head.

    These arguments were previously ignored in favour of hard-coded values
    identical to the defaults, so default construction is unchanged.
    """

    def __init__(
        self,
        depth = 3,
        dim = 512,
        heads = 8,
        dim_head = 64
    ):
        super().__init__()
        self.transformer = nn.Sequential(
            Transformer(
                dim = dim,
                seq_len=128,
                depth=depth,
                heads=heads,
                dim_head=dim_head,
                rotary_emb= False,
                attn_types=('full',),
            ),
            nn.LayerNorm(dim),
            nn.Linear(dim,32)
        )
        # The flattened transformer output must have 128 * 32 features per sample.
        self.end = nn.Sequential(
            nn.ReLU(),
            nn.Linear(128*32, 30*29)
        )

    def forward(self, x):
        # assumes x is (batch, 128, dim) — TODO confirm against callers
        x = self.transformer(x)
        x = x.reshape(x.shape[0],-1)
        x = self.end(x)
        return x

class DiscretePGVAE(nn.Module):
    """Vector-quantised autoencoder over shape programs.

    A program is given as integer statement ids (``pgms``) plus a vector of
    parameters per statement (``params``). ``forward`` one-hot encodes the ids
    over 22 classes, concatenates the 7 parameters, reshapes to 30 tokens of
    width 29, encodes them with ``t1``/``t1_logits``, quantises with
    ``quantize_layer`` and decodes with ``t2``/``t2_logits``/``end`` back to a
    ``(batch, 30, 29)`` tensor: 22 class logits followed by 7 regressed
    parameters per token.

    Several constructor arguments (``num_resnet_blocks``, ``hidden_dim``,
    ``channels``, ``dim2``, ``radius``) are accepted for signature
    compatibility but are not used by this class.

    NOTE(review): the encoder input has width 22 + 7 = 29 while ``t1`` is built
    with ``dim = dim1``, so callers presumably pass ``dim1 = 29`` -- confirm.
    NOTE(review): the encoder emits 512 channels and the decoder consumes 512
    channels, so ``codebook_dim`` presumably has to be 512 -- confirm against
    ``VectorQuantizer``.

    The module no longer pins sub-modules or intermediate tensors to CUDA;
    everything follows the device of the module / of ``pgms``.
    """

    # Fixed layout of one flattened program, as used by ``forward``.
    _SEQ_LEN = 30            # statements per program after flattening
    _NUM_PGM_CLASSES = 22    # size of the one-hot statement id
    _NUM_PARAMS = 7          # parameters attached to every statement

    def __init__(
        self,
        image_size = 256,
        num_tokens = 512,
        codebook_dim = 512,
        num_layers = 3,
        num_resnet_blocks = 0,
        hidden_dim = 64,
        channels = 3,
        smooth_l1_loss = False,
        temperature = 0.9,
        straight_through = False,
        kl_div_loss_weight = 0.,
        normalization = ((0.5,) * 3, (0.5,) * 3),
        dim1 = 16,
        dim2 = 32,
        radius= 0.3,
        final_points = 16,
        final_depth = 3,
    ):
        super().__init__()
        assert log2(image_size).is_integer(), 'image size must be a power of 2'
        assert num_layers >= 1, 'number of layers must be greater than or equal to 1'

        self.image_size = image_size
        self.num_tokens = num_tokens
        self.num_layers = num_layers
        self.temperature = temperature
        self.straight_through = straight_through
        self.codebook_dim = codebook_dim
        self.final_dim = final_depth

        self.final_points = final_points
        self.quantize_layer = VectorQuantizer(num_tokens, codebook_dim)
        self.dim1 = dim1

        token_width = self._NUM_PGM_CLASSES + self._NUM_PARAMS

        # Encoder. Previously moved to CUDA right here, which left the module
        # half on the GPU and half on the CPU until the caller moved it again;
        # device placement is now left entirely to the caller.
        self.t1 = Transformer(
            dim = self.dim1,
            seq_len = self._SEQ_LEN,
            depth = final_depth,
            heads = 8,
            dim_head = 64,
            rotary_emb = False,
            attn_types = ('full',),
        )
        self.t1_logits = nn.Sequential(
            nn.LayerNorm(self.dim1),
            nn.Linear(self.dim1, 512),
        )

        # Decoder.
        self.t2 = Transformer(
            dim = 512,
            seq_len = self._SEQ_LEN,
            depth = final_depth,
            heads = 8,
            dim_head = 64,
            rotary_emb = False,
            attn_types = ('full',),
        )
        self.t2_logits = nn.Sequential(
            nn.LayerNorm(512),
            nn.Linear(512, 32),
        )
        # 30 tokens x 32 channels are flattened and mapped to 30 x 29 outputs.
        self.end = nn.Sequential(
            nn.ReLU(),
            nn.Linear(32 * self._SEQ_LEN, self._SEQ_LEN * token_width),
        )

        self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
        self.kl_div_loss_weight = kl_div_loss_weight

        # take care of normalization within class
        self.normalization = normalization

        self._register_external_parameters()

        # PointNet++ set-abstraction stack. It is built and initialised here
        # but not called by ``forward``; it is kept so that the module's
        # parameters (and therefore its state dict) stay unchanged.
        in_channels = 6
        num_centroids = (512, self.final_points)
        sa_radius = (0.1, 0.4)
        num_neighbours = (64, 512)
        sa_channels = ((512, 512), (512, num_tokens))
        use_xyz = True

        feature_channels = in_channels - 3
        self.sa_modules = nn.ModuleList()
        for ind in range(len(num_centroids)):
            sa_module = PointNetSAModule(in_channels=feature_channels,
                                         mlp_channels=sa_channels[ind],
                                         num_centroids=num_centroids[ind],
                                         radius=sa_radius[ind],
                                         num_neighbours=num_neighbours[ind],
                                         use_xyz=use_xyz)
            self.sa_modules.append(sa_module)
            # the next layer consumes what this one produces
            feature_channels = sa_channels[ind][-1]
        self.reset_parameters()

    def _register_external_parameters(self):
        """Register external parameters for DeepSpeed partitioning."""
        if (
                not distributed_utils.is_distributed
                or not distributed_utils.using_backend(
                    distributed_utils.DeepSpeedBackend)
        ):
            return

        # Nothing is registered at the moment: the codebook registration this
        # handle was fetched for is disabled.
        deepspeed = distributed_utils.backend.backend_module
        # deepspeed.zero.register_external_parameter(self, self.codebook.weight)

    def norm(self, images):
        """Return a normalised copy of ``images`` using the stored per-channel (means, stds).

        ``images`` is returned unchanged when no normalisation is configured.
        The statistics are broadcast along the second axis of a 4-D tensor.
        """
        if not exists(self.normalization):
            return images

        means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization)
        means, stds = map(lambda t: rearrange(t, 'c -> () c () ()'), (means, stds))
        # clone first so the in-place arithmetic never touches the caller's tensor
        images = images.clone()
        images.sub_(means).div_(stds)
        return images

    def reset_parameters(self):
        """Xavier-initialise the set-abstraction modules and set batch-norm momentum to 0.01."""
        for sa_module in self.sa_modules:
            sa_module.reset_parameters(xavier_uniform)
        set_bn(self, momentum=0.01)

    @torch.no_grad()
    @eval_decorator
    def get_codebook_indices(self, pgms, params):
        """Encode a program and return the fourth output of ``quantize_layer`` (the code indices)."""
        logits = self(pgms, params, return_logits = True)
        _, _, _, indices = self.quantize_layer(logits.unsqueeze(-1))
        return indices

    def _decode_quantized(self, quantized):
        """Run the decoder on a quantised ``(batch, channels, tokens, 1)`` tensor.

        Returns a ``(batch, 30, -1)`` tensor of per-token outputs.
        """
        tokens = quantized.squeeze(-1).permute(0, 2, 1).contiguous()
        out = self.t2_logits(self.t2(tokens))
        out = out.reshape(out.shape[0], -1)
        out = self.end(out)
        return out.reshape(out.shape[0], self._SEQ_LEN, -1)

    def decode(
        self,
        indices
    ):
        """Decode codebook ``indices`` of shape ``(batch, tokens)`` into program outputs.

        The indices are turned into one-hot rows, multiplied with the codebook
        weights to fetch the embeddings, and passed through the decoder.
        """
        flat_indices = indices.reshape(-1, 1)
        encodings = torch.zeros(flat_indices.shape[0], self.quantize_layer._num_embeddings, device = indices.device)
        encodings.scatter_(1, flat_indices, 1)
        quantized = torch.matmul(encodings, self.quantize_layer._embedding.weight)
        # (batch * tokens, dim) -> (batch, dim, tokens, 1), the layout the decoder expects
        quantized = quantized.view(indices.shape[0], indices.shape[1], 1, -1).permute(0, 3, 1, 2)
        return self._decode_quantized(quantized)

    def get_encoding(self, pgms, params):
        """Encode a program and return the second output of ``quantize_layer`` (the quantised features)."""
        logits = self(pgms, params, return_logits = True)
        _, quantized, _, _ = self.quantize_layer(logits.unsqueeze(-1))
        return quantized

    def forward(
        self,
        pgms,
        params,
        pgms_masks = None,
        params_masks = None,
        return_loss = False,
        return_recons = False,
        return_logits = False,
        return_detailed_loss = False,
        temp = None,
        epoch = 0,
    ):
        """Encode, quantise and decode a batch of programs.

        Args:
            pgms: integer statement ids, 3-D ``(batch, n_block, n_step)``.
            params: statement parameters, 4-D ``(batch, n_block, n_step, n_param)``.
            pgms_masks: weights for the classification loss, same layout as
                ``pgms``. Required when ``return_loss`` is true.
            params_masks: weights for the regression loss, same layout as
                ``params``. Required when ``return_loss`` is true.
            return_loss: compute the training loss instead of returning the
                reconstruction only.
            return_recons: with ``return_loss``, also return the reconstruction
                and the perplexity.
            return_logits: return the pre-quantisation features and stop.
            return_detailed_loss: with ``return_recons``, return the loss terms
                separately.
            temp, epoch: accepted for interface compatibility; unused.

        Returns:
            * ``return_logits``: the encoder output, tokens on the last axis.
            * otherwise, without ``return_loss``: the ``(batch, 30, 29)`` reconstruction.
            * ``return_loss`` only: the scalar ``loss_cls + 3 * loss_reg``.
            * plus ``return_recons``: ``(loss, out, perplexity)``.
            * plus ``return_detailed_loss``: ``(loss_cls, loss_reg, acc, out, perplexity)``.

        Raises:
            ValueError: if ``return_loss`` is set but a mask is missing.
        """
        device = pgms.device
        n_cls, n_par = self._NUM_PGM_CLASSES, self._NUM_PARAMS

        # one-hot the statement ids on the inputs' device (no hard-coded CUDA)
        pgms_onehot = torch.zeros(pgms.shape[0], pgms.shape[1], pgms.shape[2], n_cls, device = device)
        pgms_onehot.scatter_(-1, pgms.unsqueeze(-1), 1)
        tokens = torch.cat((pgms_onehot, params), -1)
        tokens = tokens.reshape(pgms.shape[0], self._SEQ_LEN, n_cls + n_par)

        before_vq = self.t1_logits(self.t1(tokens))
        # channels first, as the quantiser expects
        before_vq = before_vq.permute(0, 2, 1).contiguous()

        if return_logits:
            return before_vq # return logits for getting hard image indices for DALL-E training

        vq_loss, sampled, perplexity, _ = self.quantize_layer(before_vq.unsqueeze(-1))

        out = self._decode_quantized(sampled)

        if not return_loss:
            return out

        if not exists(pgms_masks) or not exists(params_masks):
            raise ValueError('pgms_masks and params_masks must be provided when return_loss is True')

        # compute program classification loss
        bsz, n_block, n_step = pgms.size()
        out_pgm = F.log_softmax(out[:, :, :n_cls], dim = -1)
        pgms = pgms.contiguous().view(bsz, n_block * n_step)
        pgms_masks = pgms_masks.contiguous().view(bsz, n_block * n_step)
        pred = to_contiguous(out_pgm).view(-1, out_pgm.size(2))
        target = to_contiguous(pgms).view(-1, 1).to(device)
        mask = to_contiguous(pgms_masks).view(-1, 1).to(device)
        # masked negative log-likelihood, averaged over the unmasked entries
        loss_cls = - pred.gather(1, target) * mask
        loss_cls = torch.sum(loss_cls) / torch.sum(mask)
        _, idx = torch.max(pred, dim = 1)
        correct = idx.eq(torch.squeeze(target))
        correct = correct.float() * torch.squeeze(mask)
        acc = torch.sum(correct) / torch.sum(mask)

        # compute parameter regression loss
        bsz, n_block, n_step, n_param = params.size()
        out_param = out[:, :, n_cls:]
        params = params.contiguous().view(bsz, n_block * n_step, n_param).to(device)
        params_masks = params_masks.contiguous().view(bsz, n_block * n_step, n_param).to(device)
        diff = 0.5 * (out_param - params) ** 2
        diff = diff * params_masks
        loss_reg = torch.sum(diff) / torch.sum(params_masks)

        # the regression term is weighted three times the classification term
        loss = loss_cls + 3*loss_reg

        if not return_recons:
            return loss

        print('loss_cls:%.3f, acc:%.3f, loss_reg:%.3f'%(loss_cls, acc, loss_reg))

        if not return_detailed_loss:
            return loss, out, perplexity
        else:
            return loss_cls, loss_reg, acc, out, perplexity

# main classes

class CLIP(nn.Module):
    """Contrastive text / image model.

    Text tokens and flattened image patches are each embedded, given learned
    positional embeddings, encoded by a non-causal transformer, mean-pooled
    and projected to a shared latent space. Latents are L2-normalised and
    compared with a dot product scaled by ``exp(temperature)``.
    """

    def __init__(
        self,
        *,
        dim_text = 512,
        dim_image = 512,
        dim_latent = 512,
        num_text_tokens = 10000,
        text_enc_depth = 6,
        text_seq_len = 256,
        text_heads = 8,
        num_visual_tokens = 512,
        visual_enc_depth = 6,
        visual_heads = 8,
        visual_image_size = 256,
        visual_patch_size = 32,
        channels = 3
    ):
        super().__init__()

        # --- text tower ---
        self.text_emb = nn.Embedding(num_text_tokens, dim_text)
        self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
        self.text_transformer = Transformer(
            causal = False,
            seq_len = text_seq_len,
            dim = dim_text,
            depth = text_enc_depth,
            heads = text_heads,
            rotary_emb = False,
        )
        self.to_text_latent = nn.Linear(dim_text, dim_latent, bias = False)

        # --- image tower ---
        assert visual_image_size % visual_patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        patches_per_side = visual_image_size // visual_patch_size
        num_patches = patches_per_side ** 2
        patch_dim = channels * visual_patch_size ** 2

        self.visual_patch_size = visual_patch_size
        self.to_visual_embedding = nn.Linear(patch_dim, dim_image)
        self.visual_pos_emb = nn.Embedding(num_patches, dim_image)
        self.visual_transformer = Transformer(
            causal = False,
            seq_len = num_patches,
            dim = dim_image,
            depth = visual_enc_depth,
            heads = visual_heads,
            rotary_emb = False,
        )
        self.to_visual_latent = nn.Linear(dim_image, dim_latent, bias = False)

        # learned log-scale applied to the similarities
        self.temperature = nn.Parameter(torch.tensor(1.))

    def forward(
        self,
        text,
        image,
        text_mask = None,
        return_loss = False
    ):
        """Score text / image pairs, or return the symmetric contrastive loss.

        Args:
            text: token ids, ``(batch, seq)``.
            image: images, ``(batch, channels, height, width)``.
            text_mask: optional boolean mask selecting the text positions that
                take part in attention and in the pooled mean.
            return_loss: when true, return the loss over all pairs in the batch
                instead of the per-pair similarities.

        Returns:
            ``(batch,)`` similarities between matching pairs, or a scalar loss.
        """
        batch_size = text.shape[0]
        device = text.device
        patch = self.visual_patch_size

        # token + position embeddings for the text
        text_positions = torch.arange(text.shape[1], device = device)
        text_tokens = self.text_emb(text) + self.text_pos_emb(text_positions)

        # cut the image into flattened patches, then embed them the same way
        patches = rearrange(image, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch, p2 = patch)
        patch_tokens = self.to_visual_embedding(patches)
        patch_positions = torch.arange(patch_tokens.shape[1], device = device)
        patch_tokens = patch_tokens + self.visual_pos_emb(patch_positions)

        enc_text = self.text_transformer(text_tokens, mask = text_mask)
        enc_image = self.visual_transformer(patch_tokens)

        # pool over the sequence axis; masked positions are left out of the text mean
        if text_mask is None:
            pooled_text = enc_text.mean(dim = 1)
        else:
            pooled_text = masked_mean(enc_text, text_mask, dim = 1)
        pooled_image = enc_image.mean(dim = 1)

        text_latents = F.normalize(self.to_text_latent(pooled_text), p = 2, dim = -1)
        image_latents = F.normalize(self.to_visual_latent(pooled_image), p = 2, dim = -1)

        scale = self.temperature.exp()

        if return_loss:
            # every text against every image; the diagonal holds the true pairs
            sim = einsum('i d, j d -> i j', text_latents, image_latents) * scale
            labels = torch.arange(batch_size, device = device)
            text_to_image = F.cross_entropy(sim, labels)
            image_to_text = F.cross_entropy(sim.t(), labels)
            return (text_to_image + image_to_text) / 2

        return einsum('n d, n d -> n', text_latents, image_latents) * scale

# main DALL-E class

class DALLE(nn.Module):
    """Causal transformer over a joint sequence of text tokens and VAE codes.

    The vocabulary is the concatenation of the text vocabulary (including one
    unique padding token per text position) and the VAE codebook. A static
    mask keeps text positions from predicting codes and code positions from
    predicting text. The VAE is frozen.

    The number of code positions is taken from ``vae.final_points``.
    """

    def __init__(
        self,
        *,
        dim,
        vae,
        num_text_tokens = 10000,
        text_seq_len = 256,
        depth,
        heads = 8,
        dim_head = 64,
        reversible = False,
        attn_dropout = 0.,
        ff_dropout = 0,
        sparse_attn = False,
        attn_types = None,
        loss_img_weight = 7,
        stable = False,
        sandwich_norm = False,
        shift_tokens = True,
        rotary_emb = True,
        shared_attn_ids = None,
        shared_ff_ids = None,
        share_input_output_emb = False,
        optimize_for_inference = False,
    ):
        super().__init__()
        assert isinstance(vae, (DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE)), 'vae must be an instance of DiscreteVAE'

        num_image_tokens = vae.num_tokens
        image_fmap_size = (vae.image_size // (2 ** vae.num_layers))
        # one code per VAE output point rather than image_fmap_size ** 2
        image_seq_len = vae.final_points

        num_text_tokens = num_text_tokens + text_seq_len  # reserve unique padding tokens for each position (text seq len)

        self.text_pos_emb = nn.Embedding(text_seq_len + 1, dim) if not rotary_emb else always(0) # +1 for <bos>

        self.num_text_tokens = num_text_tokens # for offsetting logits index and calculating cross entropy loss
        self.num_image_tokens = num_image_tokens

        self.text_seq_len = text_seq_len
        self.image_seq_len = image_seq_len

        seq_len = text_seq_len + image_seq_len
        total_tokens = num_text_tokens + num_image_tokens
        self.total_tokens = total_tokens
        self.total_seq_len = seq_len

        self.vae = vae
        set_requires_grad(self.vae, False) # freeze VAE from being trained

        self.transformer = Transformer(
            dim = dim,
            causal = True,
            seq_len = seq_len,
            depth = depth,
            heads = heads,
            dim_head = dim_head,
            reversible = reversible,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            attn_types = attn_types,
            image_fmap_size = image_fmap_size,
            sparse_attn = sparse_attn,
            stable = stable,
            sandwich_norm = sandwich_norm,
            shift_tokens = shift_tokens,
            rotary_emb = rotary_emb,
            shared_attn_ids = shared_attn_ids,
            shared_ff_ids = shared_ff_ids,
            optimize_for_inference = optimize_for_inference,
        )

        self.stable = stable

        if stable:
            self.norm_by_max = DivideMax(dim = -1)

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, self.total_tokens),
        )

        if share_input_output_emb:
            # input embeddings are slices of the output projection
            self.text_emb = SharedEmbedding(self.to_logits[1], 0, num_text_tokens)
            self.image_emb = SharedEmbedding(self.to_logits[1], num_text_tokens, total_tokens)
        else:
            self.text_emb = nn.Embedding(num_text_tokens, dim)
            self.image_emb = nn.Embedding(num_image_tokens, dim)

        seq_range = torch.arange(seq_len)
        logits_range = torch.arange(total_tokens)

        seq_range = rearrange(seq_range, 'n -> () n ()')
        logits_range = rearrange(logits_range, 'd -> () () d')

        # True where a logit must be suppressed: text logits at code positions
        # and code logits at text positions.
        logits_mask = (
            ((seq_range >= text_seq_len) & (logits_range < num_text_tokens)) |
            ((seq_range < text_seq_len) & (logits_range >= num_text_tokens))
        )

        self.register_buffer('logits_mask', logits_mask, persistent=False)
        self.loss_img_weight = loss_img_weight


    @torch.no_grad()
    @eval_decorator
    def generate_texts(
        self,
        tokenizer,
        text = None,
        *,
        filter_thres = 0.5,
        temperature = 1.
    ):
        """Autoregressively complete ``text`` up to ``text_seq_len`` tokens.

        Args:
            tokenizer: object exposing ``tokenizer.encode`` / ``tokenizer.decode``.
            text: optional prompt string; generation starts from token 0 when
                it is ``None`` or empty.
            filter_thres: threshold passed to ``top_k``.
            temperature: temperature passed to ``gumbel_sample``.

        Returns:
            ``(text_tokens, texts)``: the generated token tensor and its
            decoded strings.
        """
        text_seq_len = self.text_seq_len
        # follow the model's own device instead of assuming CUDA
        device = self.to_logits[1].weight.device

        if text is None or text == "":
            text_tokens = torch.tensor([[0]], device = device)
        else:
            text_tokens = torch.tensor(tokenizer.tokenizer.encode(text), device = device).unsqueeze(0)

        for _ in range(text_tokens.shape[1], text_seq_len):
            tokens = self.text_emb(text_tokens)
            tokens += self.text_pos_emb(torch.arange(text_tokens.shape[1], device = device))

            seq_len = tokens.shape[1]

            output_transf = self.transformer(tokens)

            if self.stable:
                output_transf = self.norm_by_max(output_transf)

            logits = self.to_logits(output_transf)

            # mask logits to make sure text predicts text (except last token), and image predicts image

            logits_mask = self.logits_mask[:, :seq_len]
            max_neg_value = -torch.finfo(logits.dtype).max
            logits.masked_fill_(logits_mask, max_neg_value)
            logits = logits[:, -1, :]

            filtered_logits = top_k(logits, thres = filter_thres)
            sample = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)

            text_tokens = torch.cat((text_tokens, sample[:, None]), dim=-1)

        # the last text_seq_len text ids are the per-position padding tokens
        padding_tokens = set(np.arange(self.text_seq_len) + (self.num_text_tokens - self.text_seq_len))
        texts = [tokenizer.tokenizer.decode(text_token, pad_tokens=padding_tokens) for text_token in text_tokens]
        return text_tokens, texts

    @torch.no_grad()
    @eval_decorator
    def generate_images(
        self,
        text,
        *,
        clip = None,
        filter_thres = 0.5,
        temperature = 1.,
        img = None,
        num_init_img_tokens = None,
        cond_scale = 1.,
        use_cache = False,
    ):
        """Sample a full code sequence for ``text`` and decode it with the VAE.

        Args:
            text: text token ids; truncated to ``text_seq_len``.
            clip: optional scorer called as ``clip(text, images, return_loss = False)``.
            filter_thres: threshold passed to ``top_k``.
            temperature: temperature passed to ``gumbel_sample``.
            img: optional priming input whose leading codes seed the sequence.
            num_init_img_tokens: number of priming codes to keep; defaults to
                ``int(0.4375 * image_seq_len)``.
            cond_scale: classifier-free guidance scale; ``1`` disables it.
            use_cache: pass a cache dict to the transformer between steps.

        Returns:
            The decoded output, or ``(output, scores)`` when ``clip`` is given.
        """
        vae, text_seq_len, image_seq_len, num_text_tokens = self.vae, self.text_seq_len, self.image_seq_len, self.num_text_tokens
        total_len = text_seq_len + image_seq_len

        text = text[:, :text_seq_len] # make sure text is within bounds
        out = text

        if exists(img):
            image_size = vae.image_size
            assert img.shape[1] == 3 and img.shape[2] == image_size and img.shape[3] == image_size, f'input image must have the correct image size {image_size}'

            indices = vae.get_codebook_indices(img)
            num_img_tokens = default(num_init_img_tokens, int(0.4375 * image_seq_len))  # OpenAI used 14 * 32 initial tokens to prime
            assert num_img_tokens < image_seq_len, 'number of initial image tokens for priming must be less than the total image token sequence length'

            indices = indices[:, :num_img_tokens]
            out = torch.cat((out, indices), dim = -1)

        prev_cache = None
        cache = {} if use_cache else None
        for cur_len in range(out.shape[1], total_len):
            is_image = cur_len >= text_seq_len

            text, image = out[:, :text_seq_len], out[:, text_seq_len:]

            if cond_scale != 1 and use_cache:
                # copy the cache state to infer from the same place twice
                prev_cache = cache.copy()

            logits = self(text, image, cache = cache)

            if cond_scale != 1:
                # discovery by Katherine Crowson
                # https://twitter.com/RiversHaveWings/status/1478093658716966912
                null_cond_logits = self(text, image, null_cond_prob = 1., cache = prev_cache)
                logits = null_cond_logits + (logits - null_cond_logits) * cond_scale

            logits = logits[:, -1, :]

            filtered_logits = top_k(logits, thres = filter_thres)
            sample = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)

            sample -= (num_text_tokens if is_image else 0) # offset sampled token if it is an image token, since logit space is composed of text and then image tokens
            out = torch.cat((out, sample[:, None]), dim=-1)

        text_seq = out[:, :text_seq_len]

        img_seq = out[:, -image_seq_len:]
        images = vae.decode(img_seq)

        if exists(clip):
            scores = clip(text_seq, images, return_loss = False)
            return images, scores

        return images

    def forward(
        self,
        text,
        image = None,
        return_loss = False,
        null_cond_prob = 0.,
        cache = None,
    ):
        """Compute next-token logits, or the training loss, for text + codes.

        Args:
            text: token ids of shape ``(batch, text_seq_len)``; zeros are
                treated as padding. The tensor is not modified.
            image: optional code indices ``(batch, n)``; a 3-D tensor is first
                converted with ``vae.get_codebook_indices``.
            return_loss: return the weighted cross-entropy instead of logits.
            null_cond_prob: probability, per sample, of blanking the text
                condition.
            cache: optional dict shared across incremental decoding steps.

        Returns:
            Masked logits of shape ``(batch, seq, total_tokens)``, or a scalar
            loss ``(loss_text + w * loss_img) / (w + 1)`` with
            ``w = loss_img_weight``.
        """
        assert text.shape[-1] == self.text_seq_len, f'the length {text.shape[-1]} of the text tokens you passed in does not have the correct length ({self.text_seq_len})'
        batch, device, total_seq_len = text.shape[0], text.device, self.total_seq_len

        # randomly remove text condition with <null_cond_prob> probability

        if null_cond_prob > 0:
            null_mask = prob_mask_like((batch,), null_cond_prob, device = device)
            # Out of place on purpose: an in-place multiply would zero the
            # caller's tensor, and during guided sampling `text` is a view of
            # the running output sequence, so the prompt itself would be wiped.
            text = text * rearrange(~null_mask, 'b -> b 1')

        # make sure padding in text tokens get unique padding token id

        text_range = torch.arange(self.text_seq_len, device = device) + (self.num_text_tokens - self.text_seq_len)
        text = torch.where(text == 0, text_range, text)

        # add <bos>

        text = F.pad(text, (1, 0), value = 0)

        tokens = self.text_emb(text)
        tokens += self.text_pos_emb(torch.arange(text.shape[1], device = device))

        seq_len = tokens.shape[1]

        if exists(image) and not is_empty(image):
            is_raw_image = len(image.shape) == 3

            if is_raw_image:
                image = self.vae.get_codebook_indices(image)

            image_len = image.shape[1]
            image_emb = self.image_emb(image)

            tokens = torch.cat((tokens, image_emb), dim = 1)

            seq_len += image_len

        # when training, if the length exceeds the total text + image length
        # remove the last token, since it needs not to be trained

        if tokens.shape[1] > total_seq_len:
            seq_len -= 1
            tokens = tokens[:, :-1]

        if self.stable:
            # scale the gradient reaching the embeddings without changing the values
            alpha = 0.1
            tokens = tokens * alpha + tokens.detach() * (1 - alpha)

        if exists(cache) and cache.get('offset'):
            # earlier positions are already cached; only feed the newest token
            tokens = tokens[:, -1:]

        out = self.transformer(tokens, cache=cache)

        if self.stable:
            out = self.norm_by_max(out)

        logits = self.to_logits(out)

        # mask logits to make sure text predicts text (except last token), and image predicts image

        logits_mask = self.logits_mask[:, :seq_len]
        if exists(cache) and cache.get('offset'):
            logits_mask = logits_mask[:, -1:]
        max_neg_value = -torch.finfo(logits.dtype).max
        logits.masked_fill_(logits_mask, max_neg_value)

        if exists(cache):
            cache['offset'] = cache.get('offset', 0) + logits.shape[1]

        if not return_loss:
            return logits

        assert exists(image), 'when training, image must be supplied'

        # shift the codes into the joint vocabulary, after the text ids
        offsetted_image = image + self.num_text_tokens
        labels = torch.cat((text[:, 1:], offsetted_image), dim = 1)

        logits = rearrange(logits, 'b n c -> b c n')

        loss_text = F.cross_entropy(logits[:, :, :self.text_seq_len], labels[:, :self.text_seq_len])
        loss_img = F.cross_entropy(logits[:, :, self.text_seq_len:], labels[:, self.text_seq_len:])

        loss = (loss_text + self.loss_img_weight * loss_img) / (self.loss_img_weight + 1)
        return loss

def map_pgcode(pgms, params):
    """Flatten a program into one token row per sample: all statement ids, then all parameters.

    Parameters are shifted by +27 (documented range [-27, 29] -> [0, 56]) and
    cast to int; statement ids are shifted by +57 (documented range
    [0, 20] -> [57, 77]) so the two ranges do not overlap.
    """
    batch = params.shape[0]
    pgm_tokens = (pgms + 57).reshape(batch, -1)
    param_tokens = (params + 27).to(torch.int).reshape(batch, -1)
    return torch.cat([pgm_tokens, param_tokens], dim = -1)
def decode_pgcode(pg_codes):
    """Split token rows back into (statement ids, parameters).

    The first 30 columns are statement tokens (offset 57), the remaining
    columns are parameter tokens (offset 27).
    """
    pgm_tokens, param_tokens = pg_codes[:, :30], pg_codes[:, 30:]
    pgms = pgm_tokens - 57
    params = param_tokens - 27
    return pgms, params

def map_pgcode2(pgms, params, device = None):
    """Interleave statement ids and parameters into 240 tokens per sample.

    Every eighth slot (0, 8, 16, ...) holds a statement id shifted by +57; the
    seven slots after it hold that statement's parameters shifted by +27 and
    truncated to integers. Documented ranges from the original notes:
    chair parameters [-27, 29] -> [0, 56], table parameters [-12, 24] ->
    [15, 51], statement ids [0, 20] -> [57, 77].

    Args:
        pgms: statement ids, reshaped to ``(batch, 30)``.
        params: parameters, reshaped to ``(batch, 210)``.
        device: where to build the token tensor. Defaults to CUDA when it is
            available (the historical behaviour) and to ``pgms.device``
            otherwise, so the function also works on CPU-only hosts.

    Returns:
        ``(tokens, pgm_idx, params_idx)``: the ``(batch, 240)`` long tensor and
        the CPU index tensors of the statement and parameter slots.
    """
    if device is None:
        device = torch.device('cuda') if torch.cuda.is_available() else pgms.device

    bs = params.shape[0]

    slots = torch.arange(240)
    is_pgm_slot = slots % 8 == 0
    pgm_idx = slots[is_pgm_slot]
    params_idx = slots[~is_pgm_slot]

    placeholder = torch.zeros(bs, 240, dtype = torch.long, device = device)
    # convert directly on the target device; no round trip through host memory
    placeholder[:, pgm_idx] = (pgms.reshape(bs, -1) + 57).to(device = device, dtype = torch.long)
    placeholder[:, params_idx] = (params + 27).to(device = device, dtype = torch.long).reshape(bs, -1)

    return placeholder, pgm_idx, params_idx
def decode_pgcode2(pg_codes):
    """Undo the interleaved 240-token layout.

    Columns 0, 8, 16, ... are statement tokens (offset 57); all other columns
    are parameter tokens (offset 27). Returns ``(pgms, params)``.
    """
    slots = torch.arange(240)
    is_pgm_slot = slots % 8 == 0
    pgm_cols = slots[is_pgm_slot]
    param_cols = slots[~is_pgm_slot]

    pgms = pg_codes[:, pgm_cols] - 57
    params = pg_codes[:, param_cols] - 27
    return (pgms, params)


class DALLE_PG_Discrete(nn.Module):
    def __init__(
        self,
        *,
        dim,
        vae,
        num_text_tokens = 10000,
        text_seq_len = 256,
        depth,
        heads = 8,
        dim_head = 64,
        reversible = False,
        attn_dropout = 0.,
        ff_dropout = 0,
        sparse_attn = False,
        attn_types = None,
        loss_img_weight = 7,
        stable = False,
        sandwich_norm = False,
        shift_tokens = True,
        rotary_emb = True,
        shared_attn_ids = None,
        shared_ff_ids = None,
        share_input_output_emb = False,
        optimize_for_inference = False,
        inverse = False,
    ):
        """Build the shared transformer over text, shape and program tokens.

        Args:
            dim: model width used by every embedding and by the transformer.
            vae: discrete shape VAE. It is frozen here and must expose
                ``num_tokens``, ``image_size``, ``num_layers`` and
                ``final_points`` (the number of shape tokens per sample).
            num_text_tokens: text vocabulary size; ``text_seq_len`` extra ids
                are reserved as per-position padding tokens.
            text_seq_len: number of text tokens per sample.
            loss_img_weight: stored on the module; used by ``forward`` only
                when logging.
            stable: enables the ``DivideMax`` normalisation of the output.
            rotary_emb: when True the learned positional embeddings are
                replaced by a constant 0.
            share_input_output_emb: tie text/image input embeddings to the
                weights of ``to_logits``.
            inverse: accepted for call compatibility; not used here.
            Remaining arguments are forwarded to ``Transformer``.
        """
        super().__init__()

        num_image_tokens = vae.num_tokens
        image_fmap_size = (vae.image_size // (2 ** vae.num_layers))
        image_seq_len = vae.final_points

        num_text_tokens = num_text_tokens + text_seq_len  # reserve unique padding tokens for each position (text seq len)

        # length of a program token sequence (30 groups of 8 slots, see map_pgcode2)
        pg_seq_len = 240
        self.text_pos_emb = nn.Embedding(text_seq_len + 1, dim) if not rotary_emb else always(0) # +1 for <bos>
        self.pg_pos_emb = nn.Embedding(pg_seq_len, dim) if not rotary_emb else always(0)
        # two separate position tables for shape tokens: ``image_pos_emb`` is used
        # together with text, ``image_pos_emb2`` together with program tokens
        self.image_pos_emb = nn.Embedding(image_seq_len + 1, dim) if not rotary_emb else always(0)
        self.image_pos_emb2 = nn.Embedding(image_seq_len + 1, dim) if not rotary_emb else always(0)

        self.num_text_tokens = num_text_tokens # for offsetting logits index and calculating cross entropy loss
        self.num_image_tokens = num_image_tokens
        # program vocabulary: 57 parameter values + 21 program ids (see map_pgcode)
        num_pg_tokens = 78
        self.num_pg_tokens = num_pg_tokens

        self.text_seq_len = text_seq_len
        self.image_seq_len = image_seq_len

        seq_len = text_seq_len + image_seq_len
        self.total_seq_len = seq_len

        self.pg_seq_len = pg_seq_len
        self.total_pg_seq_len = image_seq_len + pg_seq_len
        total_tokens = num_text_tokens + num_image_tokens
        self.total_tokens = total_tokens
        total_pg_tokens = num_pg_tokens + num_image_tokens
        self.total_pg_tokens = total_pg_tokens

        self.vae = vae
        set_requires_grad(self.vae, False) # freeze VAE from being trained

        # NOTE(review): the transformer is sized for shape + program sequences;
        # confirm this also covers text + shape sequences (text_seq_len + image_seq_len).
        self.transformer = Transformer(
            dim = dim,
            causal = True,
            seq_len = self.total_pg_seq_len,
            depth = depth,
            heads = heads,
            dim_head = dim_head,
            reversible = reversible,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            attn_types = attn_types,
            image_fmap_size = image_fmap_size,
            sparse_attn = sparse_attn,
            stable = stable,
            sandwich_norm = sandwich_norm,
            shift_tokens = shift_tokens,
            rotary_emb = rotary_emb,
            shared_attn_ids = shared_attn_ids,
            shared_ff_ids = shared_ff_ids,
            optimize_for_inference = optimize_for_inference,
        )

        self.stable = stable

        if stable:
            self.norm_by_max = DivideMax(dim = -1)

        # head over the text + shape vocabulary
        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, self.total_tokens),
        )

        # head over the shape + program vocabulary
        self.to_pg_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, self.total_pg_tokens),
        )

        if share_input_output_emb:
            self.text_emb = SharedEmbedding(self.to_logits[1], 0, num_text_tokens)
            self.image_emb = SharedEmbedding(self.to_logits[1], num_text_tokens, total_tokens)
            # program tokens have no slot in ``to_logits``, so they always get their
            # own table; without it forward()/generate_pgs() fail on ``self.pgm_emb``
            self.pgm_emb = nn.Embedding(num_pg_tokens, dim)
        else:
            self.text_emb = nn.Embedding(num_text_tokens, dim)
            self.image_emb = nn.Embedding(num_image_tokens, dim)
            self.pgm_emb = nn.Embedding(num_pg_tokens, dim)

        # masks are True where a logit must be suppressed

        seq_range = torch.arange(seq_len)
        logits_range = torch.arange(total_tokens)

        seq_range = rearrange(seq_range, 'n -> () n ()')
        logits_range = rearrange(logits_range, 'd -> () () d')

        # text first, shape second: text positions may only predict text ids,
        # shape positions only shape ids
        logits_mask = (
            ((seq_range >= text_seq_len) & (logits_range < num_text_tokens)) |
            ((seq_range < text_seq_len) & (logits_range >= num_text_tokens))
        )
        self.register_buffer('logits_mask', logits_mask, persistent=False)

        # shape first, text second (vocabulary laid out as shape ids, then text ids)
        logits_mask_inverse = (
            ((seq_range >= image_seq_len) & (logits_range < num_image_tokens)) |
            ((seq_range < image_seq_len) & (logits_range >= num_image_tokens))
        )
        self.register_buffer('logits_mask_inverse', logits_mask_inverse, persistent=False)

        pgpts_seq_len = image_seq_len + pg_seq_len
        pg_seq_range = torch.arange(pgpts_seq_len)
        pg_logits_range = torch.arange(total_pg_tokens)
        pg_seq_range = rearrange(pg_seq_range, 'n -> () n ()')
        pg_logits_range = rearrange(pg_logits_range, 'd -> () () d')

        # shape first, program second
        pg_logits_mask = (
            ((pg_seq_range >= image_seq_len) & (pg_logits_range < num_image_tokens)) |
            ((pg_seq_range < image_seq_len) & (pg_logits_range >= num_image_tokens))
        )
        self.register_buffer('pg_logits_mask', pg_logits_mask, persistent=False)

        # program first, shape second
        pg_logits_mask_inverse = (
            ((pg_seq_range >= pg_seq_len) & (pg_logits_range < num_pg_tokens)) |
            ((pg_seq_range < pg_seq_len) & (pg_logits_range >= num_pg_tokens))
        )
        self.register_buffer('pg_logits_mask_inverse', pg_logits_mask_inverse, persistent=False)

        self.loss_img_weight = loss_img_weight


    @torch.no_grad()
    @eval_decorator
    def generate_texts(
        self,
        tokenizer,
        text = None,
        *,
        filter_thres = 0.5,
        temperature = 1.
    ):
        """Sample text tokens one at a time until ``text_seq_len`` is reached.

        Args:
            tokenizer: object providing ``encode(str)`` and
                ``decode(tokens, pad_tokens=...)``.
            text: optional prompt string; ``None`` or ``""`` starts from a
                single token with id 0.
            filter_thres: threshold passed to ``top_k``.
            temperature: temperature passed to ``gumbel_sample``.

        Returns:
            ``(text_tokens, texts)``: the sampled token tensor and the list of
            decoded strings, one per row.

        Note: the start tokens are moved to CUDA unconditionally.
        """
        max_len = self.text_seq_len

        if text is None or text == "":
            text_tokens = torch.tensor([[0]]).cuda()
        else:
            text_tokens = torch.tensor(tokenizer.encode(text)).cuda().unsqueeze(0)

        # every pass appends exactly one sampled token
        while text_tokens.shape[1] < max_len:
            positions = torch.arange(text_tokens.shape[1], device = text_tokens.device)

            tokens = self.text_emb(text_tokens)
            tokens += self.text_pos_emb(positions)

            hidden = self.transformer(tokens)
            if self.stable:
                hidden = self.norm_by_max(hidden)

            logits = self.to_logits(hidden)

            # suppress every logit that the position mask marks as not allowed
            fill_value = -torch.finfo(logits.dtype).max
            logits.masked_fill_(self.logits_mask[:, :tokens.shape[1]], fill_value)

            last_step_logits = logits[:, -1, :]
            filtered_logits = top_k(last_step_logits, thres = filter_thres)
            sample = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)

            text_tokens = torch.cat((text_tokens, sample[:, None]), dim=-1)

        # ids of the per-position padding tokens, which the tokenizer should drop
        first_pad_id = self.num_text_tokens - self.text_seq_len
        padding_tokens = set(np.arange(self.text_seq_len) + first_pad_id)
        texts = [tokenizer.decode(row, pad_tokens=padding_tokens) for row in text_tokens]
        return text_tokens, texts

    @torch.no_grad()
    @eval_decorator
    def generate_images(
        self,
        text,
        *,
        clip = None,
        filter_thres = 0.5,
        temperature = 1.,
        img = None,
        num_init_img_tokens = None,
        cond_scale = 1.,
        use_cache = False,
    ):
        """Sample shape tokens conditioned on text and decode them with the VAE.

        Args:
            text: text token tensor; truncated to ``text_seq_len`` columns.
            clip: optional scorer called as ``clip(text, images, return_loss=False)``.
            filter_thres: threshold passed to ``top_k``.
            temperature: temperature passed to ``gumbel_sample``.
            img: optional input whose leading VAE codes prime the generation.
            num_init_img_tokens: number of priming codes taken from ``img``;
                defaults to ``int(0.4375 * image_seq_len)``.
            cond_scale: when different from 1, logits are extrapolated from a
                second pass run with ``null_cond_prob = 1.``.
            use_cache: pass a cache dict to the model between steps.

        Returns:
            The decoded output of ``vae.decode``; with ``clip`` given, a tuple
            ``(images, scores)``.
        """
        vae = self.vae
        text_seq_len, image_seq_len = self.text_seq_len, self.image_seq_len
        num_text_tokens = self.num_text_tokens
        total_len = text_seq_len + image_seq_len

        # make sure text is within bounds
        out = text[:, :text_seq_len]

        if exists(img):
            image_size = vae.image_size
            assert img.shape[1] == 3 and img.shape[2] == image_size and img.shape[3] == image_size, f'input image must have the correct image size {image_size}'

            indices = vae.get_codebook_indices(img)
            num_img_tokens = default(num_init_img_tokens, int(0.4375 * image_seq_len))  # OpenAI used 14 * 32 initial tokens to prime
            assert num_img_tokens < image_seq_len, 'number of initial image tokens for priming must be less than the total image token sequence length'

            out = torch.cat((out, indices[:, :num_img_tokens]), dim = -1)

        use_guidance = cond_scale != 1
        cache = {} if use_cache else None
        prev_cache = None

        for cur_len in range(out.shape[1], total_len):
            text, image = out[:, :text_seq_len], out[:, text_seq_len:]

            if use_guidance and use_cache:
                # copy the cache state to infer from the same place twice
                prev_cache = cache.copy()

            logits = self(text, image, cache = cache)

            if use_guidance:
                # discovery by Katherine Crowson
                # https://twitter.com/RiversHaveWings/status/1478093658716966912
                null_cond_logits = self(text, image, null_cond_prob = 1., cache = prev_cache)
                logits = null_cond_logits + (logits - null_cond_logits) * cond_scale

            filtered_logits = top_k(logits[:, -1, :], thres = filter_thres)
            sample = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)

            # logit space is text ids followed by shape ids, so a sample drawn at
            # a shape position is shifted back into the VAE's own code range
            if cur_len >= text_seq_len:
                sample -= num_text_tokens
            out = torch.cat((out, sample[:, None]), dim=-1)

        text_seq = out[:, :text_seq_len]
        images = vae.decode(out[:, -image_seq_len:])

        if not exists(clip):
            return images

        scores = clip(text_seq, images, return_loss = False)
        return images, scores

    def forward(
        self,
        text = None,
        image = None,
        pg_data = None,
        return_loss = False,
        null_cond_prob = 0.,
        cache = None,
        pg_train = True,
        pg_infer = False,
        inverse = False,
        fixed_pos = False,
        discrete_type = 1,
        pgm_only = False,
    ):
        """Return logits, or the sum of the training losses selected by the flags.

        Up to three stages run, each adding its cross-entropy terms to ``loss``:

        1. text -> shape, unless ``pg_infer`` or ``pgm_only``. With
           ``return_loss`` False this stage returns its logits immediately.
        2. shape -> program, when ``(pg_train and not pg_infer) or pgm_only``.
        3. shape -> text, when ``inverse``.

        Args:
            text: text token ids with ``text_seq_len`` columns (0 = padding).
            image: shape codes, or a 3-D raw input that is encoded by the VAE.
            pg_data: indexable holding ``(pg_pts, pgms, pgms_masks, params,
                params_masks)`` at positions 0-4; required by stage 2.
            return_loss: see stage 1.
            null_cond_prob: probability of zeroing a sample's text (in place).
            cache: optional dict forwarded to the transformer.
            fixed_pos: in stage 2, pad the logits with ``num_text_tokens``
                leading classes and shift the labels accordingly.
            discrete_type: 1 uses ``map_pgcode``, 2 uses ``map_pgcode2`` with
                separate masked losses for program ids and parameters.
            pgm_only: run stage 2 alone.

        Returns:
            Stage-1 logits, or the accumulated scalar loss.

        Raises:
            NameError: if stage 2 runs with ``discrete_type`` other than 1 or 2.
        """
        debug = False
        if debug:
            # developer switch: dump the program data and point clouds to disk
            from pytorch3d.io import save_ply
            import os
            save_dir = './shape2prog/vqprogram_outputs/test100'
            pts_save_dir = './shape2prog/vqprogram_outputs/test100/pts'
            pg_pts, pgms, pgms_masks, params, params_masks = pg_data[0].cuda(), pg_data[1].cuda(), pg_data[2].cuda(), pg_data[3].cuda(), pg_data[4].cuda()
            save_obj = {
                'pgm': pgms,
                'param': params,
            }
            torch.save(save_obj, os.path.join(save_dir,'%04d'%(0)+'.pt'))
            pc = pg_pts
            for i in range(pc.shape[0]):
                save_ply(os.path.join(pts_save_dir,'%04d'%i+'_ori.ply'), pc[i])

        if inverse:
            # stage 1 rebinds ``text``/``image`` (and may edit ``text`` in place),
            # so keep pristine copies for stage 3
            text2 = text.clone()
            image2 = image.clone()

        # ---- stage 1: text -> shape ----
        if not pg_infer and not pgm_only:
            assert text.shape[-1] == self.text_seq_len, f'the length {text.shape[-1]} of the text tokens you passed in does not have the correct length ({self.text_seq_len})'
            batch, device, total_seq_len = text.shape[0], text.device, self.total_seq_len

            # randomly remove text condition with <null_cond_prob> probability

            if null_cond_prob > 0:
                null_mask = prob_mask_like((batch,), null_cond_prob, device = device)
                text *= rearrange(~null_mask, 'b -> b 1')

            # make sure padding in text tokens get unique padding token id

            text_range = torch.arange(self.text_seq_len, device = device) + (self.num_text_tokens - self.text_seq_len)
            text = torch.where(text == 0, text_range, text)

            # add <bos>

            text = F.pad(text, (1, 0), value = 0)

            tokens = self.text_emb(text)
            tokens += self.text_pos_emb(torch.arange(text.shape[1], device = device))

            seq_len = tokens.shape[1]

            if exists(image) and not is_empty(image):
                is_raw_image = len(image.shape) == 3

                if is_raw_image:
                    image = self.vae.get_codebook_indices(image)

                image_len = image.shape[1]
                image_emb = self.image_emb(image)

                # position 0 of this table is used for the <bos> slot in stage 3
                image_emb += self.image_pos_emb(torch.arange(image_emb.shape[1], device = device) + 1)

                tokens = torch.cat((tokens, image_emb), dim = 1)

                seq_len += image_len

            # when training, if the length exceeds the total text + image length
            # remove the last token, since it needs not to be trained

            if tokens.shape[1] > total_seq_len:
                seq_len -= 1
                tokens = tokens[:, :-1]

            if self.stable:
                alpha = 0.1
                tokens = tokens * alpha + tokens.detach() * (1 - alpha)

            if exists(cache) and cache.get('offset'):
                tokens = tokens[:, -1:]

            out = self.transformer(tokens, cache=cache)

            if self.stable:
                out = self.norm_by_max(out)

            logits = self.to_logits(out)

            # mask logits to make sure text predicts text (except last token), and image predicts image

            logits_mask = self.logits_mask[:, :seq_len]
            if exists(cache) and cache.get('offset'):
                logits_mask = logits_mask[:, -1:]
            max_neg_value = -torch.finfo(logits.dtype).max
            logits.masked_fill_(logits_mask, max_neg_value)

            if exists(cache):
                cache['offset'] = cache.get('offset', 0) + logits.shape[1]

            if not return_loss:
                return logits

            assert exists(image), 'when training, image must be supplied'

            offsetted_image = image + self.num_text_tokens
            labels = torch.cat((text[:, 1:], offsetted_image), dim = 1)

            logits = rearrange(logits, 'b n c -> b c n')

            loss_text = F.cross_entropy(logits[:, :, :self.text_seq_len], labels[:, :self.text_seq_len])
            loss_img = F.cross_entropy(logits[:, :, self.text_seq_len:], labels[:, self.text_seq_len:])

            # unweighted sum; ``loss_img_weight`` only appears in the log line below
            loss = loss_text + loss_img

        # ---- stage 2: shape -> program ----
        if (pg_train and not pg_infer) or pgm_only:
            # follow the model's device instead of assuming CUDA
            device = self.to_pg_logits[1].weight.device
            pg_pts, pgms, pgms_masks, params, params_masks = (pg_data[i].to(device) for i in range(5))

            pts_code = self.vae.get_codebook_indices(pg_pts)

            # <bos> for the shape sequence (shares id 0 with the first VAE code)
            pts_code = F.pad(pts_code, (1, 0), value = 0)
            pts_emb = self.image_emb(pts_code)
            pts_emb += self.image_pos_emb2(torch.arange(pts_emb.shape[1], device = device))
            pg_seq_len = pts_code.shape[1]

            if discrete_type == 1:
                pg_code = map_pgcode(pgms, params)
            elif discrete_type == 2:
                pg_code, pgm_idx, param_idx = map_pgcode2(pgms, params)
            else:
                # previously ``assert NameError(...)``, which can never fail
                raise NameError('non-exist discrete way')

            pg_emb = self.pgm_emb(pg_code)
            pg_emb += self.pg_pos_emb(torch.arange(pg_emb.shape[1], device = device))
            pg_tokens = torch.cat((pts_emb, pg_emb), dim = 1)
            pg_seq_len += pg_emb.shape[1]

            # drop the last token: it has nothing left to predict
            if pg_seq_len > self.total_pg_seq_len:
                pg_seq_len -= 1
                pg_tokens = pg_tokens[:, :-1]

            pg_out = self.transformer(pg_tokens)
            pg_logits = self.to_pg_logits(pg_out)

            pg_logits_mask = self.pg_logits_mask[:, :pg_seq_len]
            max_neg_value = -torch.finfo(pg_logits.dtype).max
            pg_logits.masked_fill_(pg_logits_mask, max_neg_value)

            # only the discrete_type == 2 path below fills these in
            loss_pg1 = loss_pg2 = None

            if not fixed_pos:
                offsetted_pg_code = pg_code + self.num_image_tokens
                pg_labels = torch.cat((pts_code[:, 1:], offsetted_pg_code), dim = 1)
                pg_logits = rearrange(pg_logits, 'b n c -> b c n')
                loss_pts = F.cross_entropy(pg_logits[:, :, :self.image_seq_len], pg_labels[:, :self.image_seq_len])

                if discrete_type == 1:
                    loss_pg = F.cross_entropy(pg_logits[:, :, self.image_seq_len:], pg_labels[:, self.image_seq_len:])
                else:
                    # For this path the program part of the logit axis is read as
                    # 21 program classes followed by the parameter classes.
                    pgm_start = self.num_image_tokens
                    param_start = pgm_start + 21
                    batch_size = pg_logits.shape[0]
                    program_logits = pg_logits[:, :, self.image_seq_len:]

                    # masked negative log-likelihood of the program ids
                    pgm_logits = program_logits[:, pgm_start:param_start][:, :, pgm_idx]
                    pred = F.log_softmax(pgm_logits, dim=1)
                    target = pg_code[:, pgm_idx] - 57
                    pgms_masks = pgms_masks.contiguous().view(batch_size, -1)
                    loss_cls = - pred.gather(1, target.unsqueeze(1)) * pgms_masks.reshape(batch_size, 1, -1)
                    loss_pg1 = torch.sum(loss_cls) / torch.sum(pgms_masks)

                    # masked negative log-likelihood of the parameters
                    param_logits = program_logits[:, param_start:][:, :, param_idx]
                    pred = F.log_softmax(param_logits, dim=1)
                    target = pg_code[:, param_idx]
                    params_masks = params_masks.contiguous().view(batch_size, -1)
                    loss_reg = - pred.gather(1, target.unsqueeze(1)) * params_masks.reshape(batch_size, 1, -1)
                    loss_pg2 = torch.sum(loss_reg) / torch.sum(params_masks)

                    loss_pg = loss_pg1 + 3*loss_pg2

            else:
                # Prepend ``num_text_tokens`` filler classes so that labels can be
                # expressed with the text offset added.
                # NOTE(review): the filler is 1.0, not a large negative value, so
                # these classes keep probability mass - confirm this is intended.
                neg_filled_value = torch.ones(pg_logits.shape[0], pg_logits.shape[1], self.num_text_tokens, device = pg_logits.device)
                filled_logits = torch.cat((neg_filled_value,pg_logits),-1)
                offsetted_pg_code = pg_code + self.num_image_tokens + self.num_text_tokens
                pg_labels = torch.cat((pts_code[:, 1:] + self.num_text_tokens, offsetted_pg_code), dim = 1)

                pg_logits = rearrange(filled_logits, 'b n c -> b c n')

                loss_pts = F.cross_entropy(pg_logits[:, :, :self.image_seq_len], pg_labels[:, :self.image_seq_len])
                loss_pg = F.cross_entropy(pg_logits[:, :, self.image_seq_len:], pg_labels[:, self.image_seq_len:])

            if pgm_only:
                loss = loss_pg + loss_pts
            else:
                loss = loss + loss_pg + loss_pts

            # log only the terms that were actually computed
            has_split_losses = loss_pg1 is not None
            if pgm_only:
                if has_split_losses:
                    print('total_loss:%.3f, loss_pg:%.3f, loss_pg1:%.3f, loss_pg2:%.3f, loss_pts:%.3f'%(loss, loss_pg, loss_pg1, loss_pg2, loss_pts))
                else:
                    print('total_loss:%.3f, loss_pg:%.3f, loss_pts:%.3f'%(loss, loss_pg, loss_pts))
            elif has_split_losses:
                print('total_loss:%.3f, loss_text:%.3f, loss_img:%d x %.3f, loss_pg:%.3f, loss_pg1:%.3f, loss_pg2:%.3f, loss_pts:%.3f'%(loss, loss_text, self.loss_img_weight, loss_img, loss_pg, loss_pg1, loss_pg2, loss_pts))
            else:
                print('total_loss:%.3f, loss_text:%.3f, loss_img:%d x %.3f, loss_pg:%.3f, loss_pts:%.3f'%(loss, loss_text, self.loss_img_weight, loss_img, loss_pg, loss_pts))

        # ---- stage 3: shape -> text ----
        if inverse:
            text = text2
            image = image2
            batch, device, total_seq_len = text.shape[0], text.device, self.total_seq_len

            # randomly remove text condition with <null_cond_prob> probability

            if null_cond_prob > 0:
                null_mask = prob_mask_like((batch,), null_cond_prob, device = device)
                text *= rearrange(~null_mask, 'b -> b 1')

            # make sure padding in text tokens get unique padding token id

            text_range = torch.arange(self.text_seq_len, device = device) + (self.num_text_tokens - self.text_seq_len)
            text = torch.where(text == 0, text_range, text)

            # shape tokens come first here and carry the <bos>
            image = self.vae.get_codebook_indices(image)
            image = F.pad(image, (1, 0), value = 0)
            image_emb = self.image_emb(image)
            image_emb += self.image_pos_emb(torch.arange(image_emb.shape[1], device = device))

            tokens = self.text_emb(text)
            tokens += self.text_pos_emb(torch.arange(text.shape[1], device = device)+1)
            tokens = torch.cat((image_emb, tokens), dim = 1)
            seq_len = tokens.shape[1]
            if tokens.shape[1] > total_seq_len:
                seq_len -= 1
                tokens = tokens[:, :-1]
            out = self.transformer(tokens, cache=cache)
            logits = self.to_logits(out)

            logits_mask = self.logits_mask_inverse[:, :seq_len]
            if exists(cache) and cache.get('offset'):
                logits_mask = logits_mask[:, -1:]
            max_neg_value = -torch.finfo(logits.dtype).max
            logits.masked_fill_(logits_mask, max_neg_value)

            # in this direction the vocabulary is shape ids followed by text ids
            offsetted_text = text + self.num_image_tokens
            labels = torch.cat((image[:, 1:], offsetted_text), dim = 1)
            logits = rearrange(logits, 'b n c -> b c n')

            loss_img2 = F.cross_entropy(logits[:, :, :self.image_seq_len], labels[:, :self.image_seq_len])
            loss_text2 = F.cross_entropy(logits[:, :, self.image_seq_len:], labels[:, self.image_seq_len:])

            loss = loss + loss_img2 + loss_text2
            print('total_loss2:%.3f, loss_text2:%.3f, loss_img2:%d x %.3f, '%(loss, loss_text2, 1, loss_img2))

        return loss

    @torch.no_grad()
    @eval_decorator
    def generate_pgs(
        self,
        pg_pts,
        pg = None,
        *,
        filter_thres = 0.5,
        temperature = 1.
    ):
        """Sample a 240-token shape program conditioned on a point cloud.

        The point cloud is encoded by the VAE, then program tokens are sampled
        one by one. Slots 0, 8, 16, ... of the program are program ids, drawn
        from the 21 logits that follow the shape vocabulary; all other slots
        are parameters, drawn from the logits after those 21.

        Args:
            pg_pts: input passed to ``vae.get_codebook_indices``.
            pg: unused.
            filter_thres: unused; ``top_k`` is called with a fixed ``thres = 0.9``
                (kept so existing sampling behaviour does not change).
            temperature: temperature passed to ``gumbel_sample``.

        Returns:
            ``decode_pgcode2`` applied to the sampled program tokens, i.e. a
            ``(pgms, params)`` tuple.
        """
        image_seq_len, pg_seq_len, num_image_tokens = self.image_seq_len, self.pg_seq_len, self.num_image_tokens
        total_len = image_seq_len + pg_seq_len

        # logit layout used when sampling: [shape ids | 21 program ids | parameters]
        pgm_start = num_image_tokens
        param_start = pgm_start + 21
        # program-id slots inside the program sequence (same as arange(30) * 8)
        pgm_slots = {8 * group for group in range(30)}

        out = self.vae.get_codebook_indices(pg_pts)

        for cur_len in range(out.shape[1], total_len):
            pts_code, pg_code = out[:, :image_seq_len], out[:, image_seq_len:]

            # <bos> in front of the shape codes, as in training
            pts_code = F.pad(pts_code, (1, 0), value = 0)
            pts_emb = self.image_emb(pts_code)
            pts_emb += self.image_pos_emb2(torch.arange(pts_emb.shape[1], device = pts_emb.device))
            pg_emb = self.pgm_emb(pg_code)
            pg_emb += self.pg_pos_emb(torch.arange(pg_emb.shape[1], device = pts_emb.device))

            ptspg_emb = torch.cat((pts_emb, pg_emb), dim = 1)
            cur_pg_seq_len = ptspg_emb.shape[1]
            if ptspg_emb.shape[1] > total_len:
                cur_pg_seq_len -= 1
                ptspg_emb = ptspg_emb[:, :-1]

            cur_out = self.transformer(ptspg_emb)
            cur_logits = self.to_pg_logits(cur_out)
            pg_logits_mask = self.pg_logits_mask[:, :cur_pg_seq_len]
            max_neg_value = -torch.finfo(cur_logits.dtype).max
            cur_logits.masked_fill_(pg_logits_mask, max_neg_value)

            logits = cur_logits[:, -1, :]

            filtered_logits = top_k(logits, thres = 0.9)
            # position of the token being sampled within the program sequence
            if (cur_len - image_seq_len) in pgm_slots:
                # program ids live at 57..77 in the input vocabulary
                sample = gumbel_sample(filtered_logits[:, pgm_start:param_start], temperature = temperature, dim = -1) + 57
            else:
                sample = gumbel_sample(filtered_logits[:, param_start:], temperature = temperature, dim = -1)

            out = torch.cat((out, sample[:, None]), dim=-1)

        pg_seq = out[:, -pg_seq_len:]
        return decode_pgcode2(pg_seq)



    @torch.no_grad()
    @eval_decorator
    def generate_text_cond_pts(
        self,
        tokenizer,
        pg_pts,
        pg = None,
        *,
        filter_thres = 0.5,
        temperature = 1.
    ):
        """Sample a text description conditioned on a point cloud.

        The point cloud is encoded by the VAE and ``text_seq_len`` text tokens
        are then sampled one at a time from ``to_logits``, whose vocabulary is
        read here as shape ids followed by text ids.

        Args:
            tokenizer: object providing ``decode(tokens, pad_tokens=...)``.
            pg_pts: input passed to ``vae.get_codebook_indices``.
            pg: unused.
            filter_thres: unused; ``top_k`` is called with a fixed ``thres = 0.9``.
            temperature: temperature passed to ``gumbel_sample``.

        Returns:
            ``(text_tokens, texts)``: sampled token tensor and decoded strings.
        """
        n_shape, n_text = self.image_seq_len, self.text_seq_len
        shape_vocab = self.num_image_tokens
        total_len = n_shape + n_text

        out = self.vae.get_codebook_indices(pg_pts)

        for _ in range(out.shape[1], total_len):
            # <bos> goes in front of the shape codes
            pts_code = F.pad(out[:, :n_shape], (1, 0), value = 0)
            text_code = out[:, n_shape:]

            pts_emb = self.image_emb(pts_code)
            pts_emb += self.image_pos_emb(torch.arange(pts_emb.shape[1], device = pts_emb.device))

            text_emb = self.text_emb(text_code)
            text_emb += self.text_pos_emb(torch.arange(text_emb.shape[1], device = pts_emb.device)+1)

            sequence = torch.cat((pts_emb, text_emb), dim = 1)
            visible_len = sequence.shape[1]
            if visible_len > total_len:
                visible_len -= 1
                sequence = sequence[:, :-1]

            step_logits = self.to_logits(self.transformer(sequence))
            fill_value = -torch.finfo(step_logits.dtype).max
            step_logits.masked_fill_(self.logits_mask_inverse[:, :visible_len], fill_value)

            filtered_logits = top_k(step_logits[:, -1, :], thres = 0.9)
            sample = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)
            # shift from the joint vocabulary back to plain text ids
            sample -= shape_vocab
            out = torch.cat((out, sample[:, None]), dim=-1)

        text_tokens = out[:, n_shape:]
        first_pad_id = self.num_text_tokens - self.text_seq_len
        padding_tokens = set(np.arange(self.text_seq_len) + first_pad_id)
        texts = [tokenizer.decode(row, pad_tokens=padding_tokens) for row in text_tokens]
        return text_tokens, texts







class DALLE_PG(nn.Module):
    def __init__(
        self,
        *,
        dim,
        vae,
        pgvae,
        num_text_tokens = 10000,
        text_seq_len = 256,
        depth,
        heads = 8,
        dim_head = 64,
        reversible = False,
        attn_dropout = 0.,
        ff_dropout = 0,
        sparse_attn = False,
        attn_types = None,
        loss_img_weight = 7,
        stable = False,
        sandwich_norm = False,
        shift_tokens = True,
        rotary_emb = True,
        shared_attn_ids = None,
        shared_ff_ids = None,
        share_input_output_emb = False,
        optimize_for_inference = False,
        pg_seq_len = 30,
    ):
        """Build the shared transformer with a text->vae-token head and a vae-token->program head.

        All arguments are keyword-only.

        Args:
            dim: width of every embedding and of the transformer.
            vae: discrete VAE providing the "image" tokens. Must expose
                `image_size`, `num_layers`, `num_tokens` and `final_points`
                (used as the token sequence length). Its parameters are frozen.
            pgvae: discrete VAE providing the program ("pg") tokens. Must expose
                `num_tokens`. Its parameters are frozen.
            num_text_tokens: text vocabulary size; `text_seq_len` extra ids are
                appended so every text position has its own padding token.
            text_seq_len: fixed text length (without <bos>).
            depth, heads, dim_head, reversible, attn_dropout, ff_dropout,
            sparse_attn, attn_types, stable, sandwich_norm, shift_tokens,
            rotary_emb, shared_attn_ids, shared_ff_ids, optimize_for_inference:
                forwarded unchanged to `Transformer`.
            loss_img_weight: weight of the vae-token loss relative to the
                text / program loss in `forward`.
            stable: additionally enables `DivideMax` on the transformer output.
            rotary_emb: when True no learned text positional embedding is
                created (`always(0)` is used instead).
            share_input_output_emb: tie the text / image input embeddings to
                the `to_logits` projection via `SharedEmbedding`.
            pg_seq_len: number of program tokens per sample (default 30, the
                value that used to be hard-coded).
        """
        super().__init__()

        num_image_tokens = vae.num_tokens
        image_fmap_size = (vae.image_size // (2 ** vae.num_layers))
        # the vae reports its own token count; it is not a square feature map here
        image_seq_len = vae.final_points

        num_text_tokens = num_text_tokens + text_seq_len  # reserve unique padding tokens for each position (text seq len)

        self.text_pos_emb = nn.Embedding(text_seq_len + 1, dim) if not rotary_emb else always(0) # +1 for <bos>

        self.num_text_tokens = num_text_tokens # for offsetting logits index and calculating cross entropy loss
        self.num_image_tokens = num_image_tokens

        num_pg_tokens = pgvae.num_tokens
        self.num_pg_tokens = num_pg_tokens

        self.text_seq_len = text_seq_len
        self.image_seq_len = image_seq_len

        seq_len = text_seq_len + image_seq_len
        self.total_seq_len = seq_len

        # NOTE(review): the transformer below is sized for text_seq_len + image_seq_len,
        # while the program stream has length image_seq_len + pg_seq_len -- confirm that
        # pg_seq_len <= text_seq_len is required by Transformer before raising pg_seq_len.
        self.pg_seq_len = pg_seq_len
        self.total_pg_seq_len = image_seq_len + pg_seq_len

        # logit space of the text head: [text ids | image ids]
        total_tokens = num_text_tokens + num_image_tokens
        self.total_tokens = total_tokens
        # logit space of the program head: [image ids | program ids]
        total_pg_tokens = num_pg_tokens + num_image_tokens
        self.total_pg_tokens = total_pg_tokens

        self.vae = vae
        self.pgvae = pgvae
        set_requires_grad(self.vae, False) # freeze VAE from being trained
        set_requires_grad(self.pgvae, False) # freeze VAE from being trained

        self.transformer = Transformer(
            dim = dim,
            causal = True,
            seq_len = seq_len,
            depth = depth,
            heads = heads,
            dim_head = dim_head,
            reversible = reversible,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            attn_types = attn_types,
            image_fmap_size = image_fmap_size,
            sparse_attn = sparse_attn,
            stable = stable,
            sandwich_norm = sandwich_norm,
            shift_tokens = shift_tokens,
            rotary_emb = rotary_emb,
            shared_attn_ids = shared_attn_ids,
            shared_ff_ids = shared_ff_ids,
            optimize_for_inference = optimize_for_inference,
        )

        self.stable = stable

        if stable:
            self.norm_by_max = DivideMax(dim = -1)

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, self.total_tokens),
        )

        self.to_pg_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, self.total_pg_tokens),
        )

        if share_input_output_emb:
            self.text_emb = SharedEmbedding(self.to_logits[1], 0, num_text_tokens)
            self.image_emb = SharedEmbedding(self.to_logits[1], num_text_tokens, total_tokens)
        else:
            self.text_emb = nn.Embedding(num_text_tokens, dim)
            self.image_emb = nn.Embedding(num_image_tokens, dim)

        # The program embedding is needed by forward(pg_train=True) and generate_pgs()
        # regardless of how the text/image embeddings are built, so always create it
        # (it used to be missing when share_input_output_emb was True).
        self.pgm_emb = nn.Embedding(num_pg_tokens, dim)

        # text head mask: text positions may only predict text ids, image positions only image ids
        seq_range = rearrange(torch.arange(seq_len), 'n -> () n ()')
        logits_range = rearrange(torch.arange(total_tokens), 'd -> () () d')

        logits_mask = (
            ((seq_range >= text_seq_len) & (logits_range < num_text_tokens)) |
            ((seq_range < text_seq_len) & (logits_range >= num_text_tokens))
        )
        self.register_buffer('logits_mask', logits_mask, persistent=False)

        # program head mask: image positions may only predict image ids, program positions only program ids
        pg_seq_range = rearrange(torch.arange(self.total_pg_seq_len), 'n -> () n ()')
        pg_logits_range = rearrange(torch.arange(total_pg_tokens), 'd -> () () d')
        pg_logits_mask = (
            ((pg_seq_range >= image_seq_len) & (pg_logits_range < num_image_tokens)) |
            ((pg_seq_range < image_seq_len) & (pg_logits_range >= num_image_tokens))
        )
        self.register_buffer('pg_logits_mask', pg_logits_mask, persistent=False)

        self.loss_img_weight = loss_img_weight


    @torch.no_grad()
    @eval_decorator
    def generate_texts(
        self,
        tokenizer,
        text = None,
        *,
        filter_thres = 0.5,
        temperature = 1.
    ):
        """Autoregressively complete a text prompt up to `text_seq_len` tokens.

        Args:
            tokenizer: wrapper object; `tokenizer.tokenizer.encode(str)` and
                `tokenizer.tokenizer.decode(tokens, pad_tokens=...)` are called on it.
            text: optional prompt string. `None` or `""` starts from the single token 0.
            filter_thres: threshold forwarded to `top_k`.
            temperature: temperature forwarded to `gumbel_sample`.

        Returns:
            Tuple `(text_tokens, texts)`: the token tensor of shape (1, n) and the
            list of decoded strings (one per row).
        """
        text_seq_len = self.text_seq_len

        # Create the prompt where the model's weights live instead of assuming CUDA,
        # so sampling also works on CPU or on a non-default GPU.
        device = self.to_logits[1].weight.device

        if text is None or text == "":
            text_tokens = torch.tensor([[0]], device = device)
        else:
            text_tokens = torch.tensor(tokenizer.tokenizer.encode(text), device = device).unsqueeze(0)

        for _ in range(text_tokens.shape[1], text_seq_len):
            tokens = self.text_emb(text_tokens)
            tokens += self.text_pos_emb(torch.arange(text_tokens.shape[1], device = device))

            seq_len = tokens.shape[1]

            output_transf = self.transformer(tokens)

            if self.stable:
                output_transf = self.norm_by_max(output_transf)

            logits = self.to_logits(output_transf)

            # mask logits to make sure text predicts text (except last token), and image predicts image

            logits_mask = self.logits_mask[:, :seq_len]
            max_neg_value = -torch.finfo(logits.dtype).max
            logits.masked_fill_(logits_mask, max_neg_value)

            # only the prediction for the newest position is sampled
            logits = logits[:, -1, :]

            filtered_logits = top_k(logits, thres = filter_thres)
            sample = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)

            text_tokens = torch.cat((text_tokens, sample[:, None]), dim=-1)

        # the last `text_seq_len` text ids are the per-position padding tokens
        padding_tokens = set(np.arange(self.text_seq_len) + (self.num_text_tokens - self.text_seq_len))
        texts = [tokenizer.tokenizer.decode(text_token, pad_tokens=padding_tokens) for text_token in text_tokens]
        return text_tokens, texts

    @torch.no_grad()
    @eval_decorator
    def generate_images(
        self,
        text,
        *,
        clip = None,
        filter_thres = 0.5,
        temperature = 1.,
        img = None,
        num_init_img_tokens = None,
        cond_scale = 1.,
        use_cache = False,
    ):
        """Sample vae tokens conditioned on `text` and decode them with the vae.

        Args:
            text: text token tensor; truncated to `text_seq_len` columns.
            clip: optional scorer, called as `clip(text_seq, images, return_loss=False)`.
            filter_thres: threshold forwarded to `top_k`.
            temperature: temperature forwarded to `gumbel_sample`.
            img: optional priming input; must be (b, 3, image_size, image_size).
                Its first `num_init_img_tokens` codes seed the generation.
            num_init_img_tokens: number of priming codes; defaults to
                `int(0.4375 * image_seq_len)`.
            cond_scale: guidance scale; values other than 1 trigger an extra
                forward pass with the text condition removed.
            use_cache: pass a cache dict through the forward calls.

        Returns:
            `vae.decode(img_seq)`, or `(images, scores)` when `clip` is given.
        """
        vae, text_seq_len, image_seq_len, num_text_tokens = self.vae, self.text_seq_len, self.image_seq_len, self.num_text_tokens
        total_len = text_seq_len + image_seq_len

        text = text[:, :text_seq_len] # make sure text is within bounds
        out = text

        if exists(img):
            image_size = vae.image_size
            assert img.shape[1] == 3 and img.shape[2] == image_size and img.shape[3] == image_size, f'input image must have the correct image size {image_size}'

            indices = vae.get_codebook_indices(img)
            num_img_tokens = default(num_init_img_tokens, int(0.4375 * image_seq_len))  # OpenAI used 14 * 32 initial tokens to prime
            assert num_img_tokens < image_seq_len, 'number of initial image tokens for priming must be less than the total image token sequence length'

            indices = indices[:, :num_img_tokens]
            out = torch.cat((out, indices), dim = -1)

        prev_cache = None
        cache = {} if use_cache else None
        for cur_len in range(out.shape[1], total_len):
            is_image = cur_len >= text_seq_len

            # both are views into `out`
            text, image = out[:, :text_seq_len], out[:, text_seq_len:]

            if cond_scale != 1 and use_cache:
                # copy the cache state to infer from the same place twice
                prev_cache = cache.copy()

            logits = self(text, image, cache = cache)

            if cond_scale != 1:
                # discovery by Katherine Crowson
                # https://twitter.com/RiversHaveWings/status/1478093658716966912
                # Give the unconditional pass its own copy of the prompt: `text` shares
                # storage with `out` (and with the caller's tensor), so any in-place
                # blanking of the condition would otherwise erase the prompt for every
                # later step and for the CLIP scoring below.
                null_cond_logits = self(text.clone(), image, null_cond_prob = 1., cache = prev_cache)
                logits = null_cond_logits + (logits - null_cond_logits) * cond_scale

            logits = logits[:, -1, :]

            filtered_logits = top_k(logits, thres = filter_thres)
            sample = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)

            sample -= (num_text_tokens if is_image else 0) # offset sampled token if it is an image token, since logit space is composed of text and then image tokens
            out = torch.cat((out, sample[:, None]), dim=-1)

        text_seq = out[:, :text_seq_len]

        img_seq = out[:, -image_seq_len:]
        images = vae.decode(img_seq)

        if exists(clip):
            scores = clip(text_seq, images, return_loss = False)
            return images, scores

        return images

    def forward(
        self,
        text = None,
        image = None,
        pg_data = None,
        return_loss = False,
        null_cond_prob = 0.,
        cache = None,
        pg_train = True,
        pg_infer = False,
    ):
        """Run the text->vae-token stream and, when training, the vae-token->program stream.

        Args:
            text: integer token tensor whose last dimension equals `text_seq_len`;
                zeros are treated as padding.
            image: either vae code indices, or a 3-dimensional raw input that is
                encoded with `vae.get_codebook_indices`. Required when
                `return_loss` is True.
            pg_data: indexable container; entries 0, 1 and 3 must be tensors
                (presumably points, program tokens and program parameters --
                confirm against the data loader). Required when both
                `return_loss` and `pg_train` are True.
            return_loss: return a scalar loss instead of logits.
            null_cond_prob: probability of blanking the text condition per sample.
            cache: optional dict shared with the transformer for incremental decoding.
            pg_train: also compute the program loss and average it with the text loss.
            pg_infer: not supported here; see Raises.

        Returns:
            Logits of shape (batch, seq, total_tokens) when `return_loss` is False,
            otherwise the combined loss.

        Raises:
            NotImplementedError: if `pg_infer` is True (there is no such code path;
                previously this ended in an unbound-variable error).
        """
        if pg_infer:
            raise NotImplementedError(
                'forward() has no pg_infer code path; call generate_pgs() to sample programs'
            )

        assert text.shape[-1] == self.text_seq_len, f'the length {text.shape[-1]} of the text tokens you passed in does not have the correct length ({self.text_seq_len})'
        batch, device, total_seq_len = text.shape[0], text.device, self.total_seq_len

        # randomly remove text condition with <null_cond_prob> probability

        if null_cond_prob > 0:
            null_mask = prob_mask_like((batch,), null_cond_prob, device = device)
            # out-of-place on purpose: the caller's tensor may be a view of a
            # larger buffer and must not be modified
            text = text.masked_fill(rearrange(null_mask, 'b -> b 1'), 0)

        # make sure padding in text tokens get unique padding token id

        text_range = torch.arange(self.text_seq_len, device = device) + (self.num_text_tokens - self.text_seq_len)
        text = torch.where(text == 0, text_range, text)

        # add <bos>

        text = F.pad(text, (1, 0), value = 0)

        tokens = self.text_emb(text)
        tokens += self.text_pos_emb(torch.arange(text.shape[1], device = device))

        seq_len = tokens.shape[1]

        if exists(image) and not is_empty(image):
            is_raw_image = len(image.shape) == 3

            if is_raw_image:
                image = self.vae.get_codebook_indices(image)

            image_len = image.shape[1]
            image_emb = self.image_emb(image)

            tokens = torch.cat((tokens, image_emb), dim = 1)

            seq_len += image_len

        # when training, if the length exceeds the total text + image length
        # remove the last token, since it needs not to be trained

        if tokens.shape[1] > total_seq_len:
            seq_len -= 1
            tokens = tokens[:, :-1]

        if self.stable:
            # scale down the gradient flowing into the embeddings
            alpha = 0.1
            tokens = tokens * alpha + tokens.detach() * (1 - alpha)

        if exists(cache) and cache.get('offset'):
            # earlier positions are already cached; feed only the newest token
            tokens = tokens[:, -1:]

        out = self.transformer(tokens, cache=cache)

        if self.stable:
            out = self.norm_by_max(out)

        logits = self.to_logits(out)

        # mask logits to make sure text predicts text (except last token), and image predicts image

        logits_mask = self.logits_mask[:, :seq_len]
        if exists(cache) and cache.get('offset'):
            logits_mask = logits_mask[:, -1:]
        max_neg_value = -torch.finfo(logits.dtype).max
        logits.masked_fill_(logits_mask, max_neg_value)

        if exists(cache):
            cache['offset'] = cache.get('offset', 0) + logits.shape[1]

        if not return_loss:
            return logits

        assert exists(image), 'when training, image must be supplied'

        # labels live in the joint logit space: [text ids | image ids]
        offsetted_image = image + self.num_text_tokens
        labels = torch.cat((text[:, 1:], offsetted_image), dim = 1)

        logits = rearrange(logits, 'b n c -> b c n')

        loss_text = F.cross_entropy(logits[:, :, :self.text_seq_len], labels[:, :self.text_seq_len])
        loss_img = F.cross_entropy(logits[:, :, self.text_seq_len:], labels[:, self.text_seq_len:])

        loss = (loss_text + self.loss_img_weight * loss_img) / (self.loss_img_weight + 1)

        if pg_train:
            assert exists(pg_data), 'when pg_train is set, pg_data must be supplied'

            # Only entries 0, 1 and 3 are used; move them next to the text batch
            # rather than unconditionally onto the current CUDA device.
            pg_pts, pgms, params = (pg_data[idx].to(device) for idx in (0, 1, 3))

            pts_code = self.vae.get_codebook_indices(pg_pts)
            # prepend token 0 as the start-of-sequence marker
            pts_code = F.pad(pts_code, (1, 0), value = 0)
            pts_emb = self.image_emb(pts_code)
            pg_seq_len = pts_code.shape[1]

            pg_code = self.pgvae.get_codebook_indices(pgms, params)

            pg_emb = self.pgm_emb(pg_code)
            pg_tokens = torch.cat((pts_emb, pg_emb), dim = 1)
            pg_seq_len += pg_emb.shape[1]

            # drop the final token: nothing is predicted from it
            if pg_seq_len > self.total_pg_seq_len:
                pg_seq_len -= 1
                pg_tokens = pg_tokens[:, :-1]

            pg_out = self.transformer(pg_tokens)
            pg_logits = self.to_pg_logits(pg_out)

            pg_logits_mask = self.pg_logits_mask[:, :pg_seq_len]
            max_neg_value = -torch.finfo(pg_logits.dtype).max
            pg_logits.masked_fill_(pg_logits_mask, max_neg_value)

            # labels live in the joint logit space: [image ids | program ids]
            offsetted_pg_code = pg_code + self.num_image_tokens
            pg_labels = torch.cat((pts_code[:, 1:], offsetted_pg_code), dim = 1)

            pg_logits = rearrange(pg_logits, 'b n c -> b c n')

            loss_pts = F.cross_entropy(pg_logits[:, :, :self.image_seq_len], pg_labels[:, :self.image_seq_len])
            loss_pg = F.cross_entropy(pg_logits[:, :, self.image_seq_len:], pg_labels[:, self.image_seq_len:])

            # equal-weight average of the text stream loss and the program stream loss
            loss = 1/2*loss + 1/2*((loss_pg + self.loss_img_weight *loss_pts)/ (self.loss_img_weight +1))
            print('total_loss:%.3f, loss_text:%.3f, loss_img:%d x %.3f, loss_pg:%.3f, loss_pts:%.3f'%(loss, loss_text, self.loss_img_weight, loss_img, loss_pg, loss_pts))

        return loss

    @torch.no_grad()
    @eval_decorator
    def generate_pgs(
        self,
        pg_pts,
        pg = None,
        *,
        filter_thres = 0.5,
        temperature = 1.
    ):
        """Sample `pg_seq_len` program tokens conditioned on `pg_pts` and decode them.

        Args:
            pg_pts: input passed to `vae.get_codebook_indices` to obtain the
                conditioning codes.
            pg: accepted but currently unused.
            filter_thres: accepted but currently unused (see note in the loop).
            temperature: temperature forwarded to `gumbel_sample`.

        Returns:
            `pgvae.decode` applied to the last `pg_seq_len` sampled tokens.
        """
        image_seq_len, pg_seq_len, num_image_tokens = self.image_seq_len, self.pg_seq_len, self.num_image_tokens
        total_len = image_seq_len + pg_seq_len

        pts_code = self.vae.get_codebook_indices(pg_pts)
        out = pts_code

        for cur_len in range(out.shape[1], total_len):
            is_pg = cur_len >= image_seq_len

            pts_code, pg_code = out[:, :image_seq_len], out[:, image_seq_len:]

            # prepend token 0 as the start-of-sequence marker, as done in training
            pts_code = F.pad(pts_code, (1, 0), value = 0)
            pts_emb = self.image_emb(pts_code)
            pg_emb = self.pgm_emb(pg_code)
            ptspg_emb = torch.cat((pts_emb, pg_emb), dim = 1)
            cur_pg_seq_len = ptspg_emb.shape[1]
            if cur_pg_seq_len > total_len:
                cur_pg_seq_len -= 1
                # trim the tensor that is actually fed to the transformer
                # (this used to assign to an undefined name `tokens`)
                ptspg_emb = ptspg_emb[:, :-1]

            cur_out = self.transformer(ptspg_emb)
            cur_logits = self.to_pg_logits(cur_out)
            pg_logits_mask = self.pg_logits_mask[:, :cur_pg_seq_len]
            max_neg_value = -torch.finfo(cur_logits.dtype).max
            cur_logits.masked_fill_(pg_logits_mask, max_neg_value)

            logits = cur_logits[:, -1, :]

            # NOTE(review): the threshold is fixed at 0.9 and the `filter_thres`
            # argument is ignored; kept as-is so default sampling does not change.
            filtered_logits = top_k(logits, thres = 0.9)
            sample = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)

            sample -= (num_image_tokens if is_pg else 0) # offset sampled token if it is a program token, since logit space is composed of image and then program tokens
            out = torch.cat((out, sample[:, None]), dim=-1)

        pg_seq = out[:, -pg_seq_len:]
        return self.pgvae.decode(pg_seq)